scirs2_linalg/quantization/calibration/
utils.rs

1//! Utility functions for quantization calibration
2//!
3//! This module contains helper functions used by both matrix and vector
4//! calibration methods including min/max finding, histograms, KL divergence
5//! optimization, and parameter creation.
6
7use super::super::{QuantizationMethod, QuantizationParams};
8use crate::error::LinalgResult;
9use scirs2_core::ndarray::{ArrayView1, ArrayView2};
10use std::fmt::Debug;
11
12// -------------------------------------------------------------------------
13// Helper functions
14// -------------------------------------------------------------------------
15
16/// Find the minimum and maximum values in a matrix
17#[allow(dead_code)]
18pub fn find_min_max<F>(matrix: &ArrayView2<F>) -> (f32, f32)
19where
20    F: scirs2_core::numeric::Float + scirs2_core::numeric::AsPrimitive<f32>,
21{
22    let mut min_val = f32::MAX;
23    let mut max_val = f32::MIN;
24
25    for &val in matrix.iter() {
26        let val_f32 = val.as_();
27        if val_f32.is_finite() {
28            min_val = min_val.min(val_f32);
29            max_val = max_val.max(val_f32);
30        }
31    }
32
33    // Handle edge cases
34    if !min_val.is_finite() || !max_val.is_finite() {
35        min_val = 0.0;
36        max_val = 1.0;
37    }
38
39    if min_val == max_val {
40        min_val -= 1.0;
41        max_val += 1.0;
42    }
43
44    (min_val, max_val)
45}
46
47/// Find the minimum and maximum values in a vector
48#[allow(dead_code)]
49pub fn find_min_max_vec<F>(vector: &ArrayView1<F>) -> (f32, f32)
50where
51    F: scirs2_core::numeric::Float + scirs2_core::numeric::AsPrimitive<f32>,
52{
53    let mut min_val = f32::MAX;
54    let mut max_val = f32::MIN;
55
56    for &val in vector.iter() {
57        let val_f32 = val.as_();
58        if val_f32.is_finite() {
59            min_val = min_val.min(val_f32);
60            max_val = max_val.max(val_f32);
61        }
62    }
63
64    // Handle edge cases
65    if !min_val.is_finite() || !max_val.is_finite() {
66        min_val = 0.0;
67        max_val = 1.0;
68    }
69
70    if min_val == max_val {
71        min_val -= 1.0;
72        max_val += 1.0;
73    }
74
75    (min_val, max_val)
76}
77
78/// Create a histogram of values from a matrix
79#[allow(dead_code)]
80pub fn create_histogram<F>(
81    matrix: &ArrayView2<F>,
82    min_val: f32,
83    max_val: f32,
84    num_bins: usize,
85) -> Vec<usize>
86where
87    F: scirs2_core::numeric::Float + scirs2_core::numeric::AsPrimitive<f32>,
88{
89    let mut histogram = vec![0; num_bins];
90    let bin_width = (max_val - min_val) / num_bins as f32;
91
92    if bin_width == 0.0 {
93        // All values are the same, put them all in the middle bin
94        histogram[num_bins / 2] = matrix.len();
95        return histogram;
96    }
97
98    for &_val in matrix.iter() {
99        let val_f32 = _val.as_();
100        if val_f32.is_finite() {
101            let bin_idx = ((val_f32 - min_val) / bin_width).floor() as usize;
102            let bin_idx = bin_idx.min(num_bins - 1); // Ensure we don't go out of bounds
103            histogram[bin_idx] += 1;
104        }
105    }
106
107    histogram
108}
109
110/// Create a histogram of values from a vector
111#[allow(dead_code)]
112pub fn create_histogram_vec<F>(
113    vector: &ArrayView1<F>,
114    min_val: f32,
115    max_val: f32,
116    num_bins: usize,
117) -> Vec<usize>
118where
119    F: scirs2_core::numeric::Float + scirs2_core::numeric::AsPrimitive<f32>,
120{
121    let mut histogram = vec![0; num_bins];
122    let bin_width = (max_val - min_val) / num_bins as f32;
123
124    if bin_width == 0.0 {
125        // All values are the same, put them all in the middle bin
126        histogram[num_bins / 2] = vector.len();
127        return histogram;
128    }
129
130    for &_val in vector.iter() {
131        let val_f32 = _val.as_();
132        if val_f32.is_finite() {
133            let bin_idx = ((val_f32 - min_val) / bin_width).floor() as usize;
134            let bin_idx = bin_idx.min(num_bins - 1); // Ensure we don't go out of bounds
135            histogram[bin_idx] += 1;
136        }
137    }
138
139    histogram
140}
141
/// Optimize clipping thresholds using KL divergence
///
/// The histogram is normalized into a probability distribution and a grid of
/// candidate clipping ranges is scored with the KL divergence between that
/// distribution and its quantized counterpart. The candidate with the
/// smallest divergence wins; on ties the earliest candidate is kept.
///
/// # Arguments
///
/// * `histogram` - Bin counts covering `[min_val, max_val]` with equal widths
/// * `min_val` / `max_val` - Range the histogram was built over
/// * `bits` - Quantization bit width
/// * `symmetric` - Search a range `(-a, a)` instead of an arbitrary `(lo, hi)`
///
/// # Returns
///
/// The selected `(lower, upper)` thresholds. When the histogram holds no
/// counts there is nothing to optimize and the full range is returned:
/// `(-abs_max, abs_max)` for symmetric mode, `(min_val, max_val)` otherwise.
#[allow(dead_code)]
pub fn optimize_thresholds_kl_divergence(
    histogram: &[usize],
    min_val: f32,
    max_val: f32,
    bits: u8,
    symmetric: bool,
) -> (f32, f32) {
    let full_abs_max = max_val.abs().max(min_val.abs());

    // An empty histogram would normalize to 0/0 = NaN probabilities.
    let total_count: usize = histogram.iter().sum();
    if total_count == 0 {
        return if symmetric {
            (-full_abs_max, full_abs_max)
        } else {
            (min_val, max_val)
        };
    }

    let bin_width = (max_val - min_val) / histogram.len() as f32;

    // Convert histogram to probability distribution
    let total = total_count as f32;
    let distribution: Vec<f32> = histogram
        .iter()
        .map(|&count| count as f32 / total)
        .collect();

    // Number of quantization levels: 2^(bits - 1) for the signed (symmetric)
    // case, 2^bits for the unsigned one. Saturating/checked arithmetic keeps
    // out-of-range bit widths from panicking.
    let level_bits = if symmetric {
        bits.saturating_sub(1)
    } else {
        bits
    };
    let levels = 1usize
        .checked_shl(u32::from(level_bits))
        .unwrap_or(usize::MAX);
    // At least one interval, so the quantization step is never infinite.
    let intervals = (levels - 1).max(1) as f32;

    if symmetric {
        // Candidates lie on a fixed grid anchored at the full-range abs_max:
        // start + i * step for i in 0..40, where start = abs_max - 20 * step.
        // The anchor must not move while the best candidate is updated.
        let step = (full_abs_max / 20.0).max(1e-6);
        let search_start = full_abs_max - 20.0 * step;

        let mut best_abs_max = full_abs_max;
        let mut min_kl = f32::MAX;

        for i in 0..40 {
            let abs_max = search_start + i as f32 * step;
            if abs_max <= 0.0 {
                continue;
            }

            let quantization_step = abs_max / intervals;
            let kl = calculate_kl_divergence_symmetric(
                &distribution,
                min_val,
                max_val,
                bin_width,
                abs_max,
                quantization_step,
            );

            if kl < min_kl {
                min_kl = kl;
                best_abs_max = abs_max;
            }
        }

        (-best_abs_max, best_abs_max)
    } else {
        // Grid search: shrink each end inwards by up to 9/40 of the range.
        let mut best_min = min_val;
        let mut best_max = max_val;
        let mut min_kl = f32::MAX;

        let min_step = (max_val - min_val) / 40.0;
        let max_step = min_step;

        for i in 0..10 {
            let trial_min = min_val + i as f32 * min_step;

            for j in 0..10 {
                let trial_max = max_val - j as f32 * max_step;

                if trial_min >= trial_max {
                    continue;
                }

                let quantization_step = (trial_max - trial_min) / intervals;
                let kl = calculate_kl_divergence_asymmetric(
                    &distribution,
                    min_val,
                    max_val,
                    bin_width,
                    trial_min,
                    trial_max,
                    quantization_step,
                );

                if kl < min_kl {
                    min_kl = kl;
                    best_min = trial_min;
                    best_max = trial_max;
                }
            }
        }

        (best_min, best_max)
    }
}

/// Score a quantizer against a binned distribution with KL divergence
///
/// Each bin is represented by its center value. `quantize` maps that value to
/// its quantized counterpart, and the bin's probability is moved to the bin
/// containing the quantized value. Probability whose quantized value falls
/// outside the histogram is dropped. The result is `sum(p * ln(p / q))` over
/// bins with `p > 0`, with `q` floored at `1e-10`.
fn kl_divergence_after_quantization<Q>(
    distribution: &[f32],
    min_val: f32,
    bin_width: f32,
    quantize: Q,
) -> f32
where
    Q: Fn(f32) -> f32,
{
    let num_bins = distribution.len();
    let mut quantized_dist = vec![0.0_f32; num_bins];

    for (bin_idx, &prob) in distribution.iter().enumerate() {
        // Original value at the center of this bin
        let orig_val = min_val + (bin_idx as f32 + 0.5) * bin_width;

        // Map the quantized value back to a bin index
        let new_bin_idx = ((quantize(orig_val) - min_val) / bin_width).floor() as i32;

        if new_bin_idx >= 0 && new_bin_idx < num_bins as i32 {
            quantized_dist[new_bin_idx as usize] += prob;
        }
    }

    let mut kl = 0.0;
    for (&p, &q) in distribution.iter().zip(quantized_dist.iter()) {
        if p > 0.0 {
            // Floor q to avoid division by zero
            kl += p * (p / q.max(1e-10)).ln();
        }
    }

    kl
}

/// Calculate KL divergence for symmetric quantization
///
/// Bin centers are clipped to `[-abs_max, abs_max]`; values inside that range
/// are rounded to the nearest multiple of `quantization_step`.
#[allow(dead_code)]
fn calculate_kl_divergence_symmetric(
    distribution: &[f32],
    min_val: f32,
    _max_val: f32,
    bin_width: f32,
    abs_max: f32,
    quantization_step: f32,
) -> f32 {
    kl_divergence_after_quantization(distribution, min_val, bin_width, |value| {
        if value > abs_max {
            abs_max
        } else if value < -abs_max {
            -abs_max
        } else {
            // Round to nearest quantization step
            (value / quantization_step).round() * quantization_step
        }
    })
}

/// Calculate KL divergence for asymmetric quantization
///
/// Bin centers are clipped to `[quant_min, quant_max]`; values inside that
/// range are rounded to the nearest step counted from `quant_min`.
#[allow(dead_code)]
fn calculate_kl_divergence_asymmetric(
    distribution: &[f32],
    min_val: f32,
    _max_val: f32,
    bin_width: f32,
    quant_min: f32,
    quant_max: f32,
    quantization_step: f32,
) -> f32 {
    kl_divergence_after_quantization(distribution, min_val, bin_width, |value| {
        if value > quant_max {
            quant_max
        } else if value < quant_min {
            quant_min
        } else {
            // Round to nearest quantization step
            let steps = ((value - quant_min) / quantization_step).round();
            quant_min + steps * quantization_step
        }
    })
}
348
349/// Optimize symmetric scale factor using MSE
350#[allow(dead_code)]
351pub fn optimize_symmetric_scale<F>(matrix: &ArrayView2<F>, bits: u8, basescale: f32) -> f32
352where
353    F: scirs2_core::numeric::Float
354        + Debug
355        + scirs2_core::numeric::AsPrimitive<f32>
356        + scirs2_core::numeric::FromPrimitive,
357    f32: scirs2_core::numeric::AsPrimitive<F>,
358{
359    let num_trials = 20;
360    let scales: Vec<f32> = (0..num_trials)
361        .map(|i| {
362            let factor = 0.5 + 1.5 * (i as f32 / (num_trials - 1) as f32);
363            basescale * factor
364        })
365        .collect();
366
367    let mut best_scale = basescale;
368    let mut min_mse = f32::MAX;
369
370    // Test each scale factor
371    for &scale in &scales {
372        // Create temporary quantization parameters
373        let abs_max = matrix
374            .mapv(|x| x.as_().abs())
375            .fold(0.0, |a: f32, &b| a.max(b));
376        let params = QuantizationParams {
377            bits,
378            scale,
379            zero_point: 0,
380            min_val: -abs_max,
381            max_val: abs_max,
382            method: if bits == 4 {
383                QuantizationMethod::Int4
384            } else {
385                QuantizationMethod::Symmetric
386            },
387            data_type: determine_data_type(bits),
388            channel_scales: None,
389            channel_zero_points: None,
390        };
391
392        // Manually simulate quantization and dequantization for F type
393        let matrix_f32 = matrix.mapv(|x| x.as_());
394        let current_scale = params.scale;
395        let dequantized = matrix_f32.mapv(|x| {
396            let quantized = (x / scale)
397                .round()
398                .clamp(-(1 << (bits - 1)) as f32, ((1 << (bits - 1)) - 1) as f32);
399            quantized * current_scale
400        });
401
402        // Calculate MSE
403        let mse = (&matrix_f32 - &dequantized).mapv(|x| x * x).sum() / matrix.len() as f32;
404
405        if mse < min_mse {
406            min_mse = mse;
407            best_scale = scale;
408        }
409    }
410
411    best_scale
412}
413
414/// Optimize symmetric scale factor for vectors using MSE
415#[allow(dead_code)]
416pub fn optimize_symmetric_scale_vec<F>(_vector: &ArrayView1<F>, bits: u8, basescale: f32) -> f32
417where
418    F: scirs2_core::numeric::Float
419        + Debug
420        + scirs2_core::numeric::AsPrimitive<f32>
421        + scirs2_core::numeric::FromPrimitive,
422    f32: scirs2_core::numeric::AsPrimitive<F>,
423{
424    let num_trials = 20;
425    let scales: Vec<f32> = (0..num_trials)
426        .map(|i| {
427            let factor = 0.5 + 1.5 * (i as f32 / (num_trials - 1) as f32);
428            basescale * factor
429        })
430        .collect();
431
432    let mut best_scale = basescale;
433    let mut min_mse = f32::MAX;
434
435    // Test each scale factor
436    for &scale in &scales {
437        // Create temporary QuantizationParams
438        let abs_max = _vector
439            .mapv(|x| x.as_().abs())
440            .fold(0.0, |a: f32, &b| a.max(b));
441        let params = QuantizationParams {
442            bits,
443            scale,
444            zero_point: 0,
445            min_val: -abs_max,
446            max_val: abs_max,
447            method: if bits == 4 {
448                QuantizationMethod::Int4
449            } else {
450                QuantizationMethod::Symmetric
451            },
452            data_type: determine_data_type(bits),
453            channel_scales: None,
454            channel_zero_points: None,
455        };
456
457        // Manually simulate quantization and dequantization for F type
458        let vector_f32 = _vector.mapv(|x| x.as_());
459        let current_scale = params.scale;
460        let dequantized = vector_f32.mapv(|x| {
461            let quantized = (x / scale)
462                .round()
463                .clamp(-(1 << (bits - 1)) as f32, ((1 << (bits - 1)) - 1) as f32);
464            quantized * current_scale
465        });
466
467        // Calculate MSE
468        let mse = (&vector_f32 - &dequantized).mapv(|x| x * x).sum() / _vector.len() as f32;
469
470        if mse < min_mse {
471            min_mse = mse;
472            best_scale = scale;
473        }
474    }
475
476    best_scale
477}
478
479/// Optimize affine quantization parameters (scale and zero point) using MSE
480#[allow(dead_code)]
481pub fn optimize_affine_params<F>(
482    matrix: &ArrayView2<F>,
483    bits: u8,
484    base_scale: f32,
485    base_zero_point: i32,
486) -> (f32, i32)
487where
488    F: scirs2_core::numeric::Float
489        + Debug
490        + scirs2_core::numeric::AsPrimitive<f32>
491        + scirs2_core::numeric::FromPrimitive,
492    f32: scirs2_core::numeric::AsPrimitive<F>,
493{
494    let num_scale_trials = 10;
495    let num_zp_trials = 5;
496
497    let scales: Vec<f32> = (0..num_scale_trials)
498        .map(|i| {
499            let factor = 0.8 + 0.4 * (i as f32 / (num_scale_trials - 1) as f32);
500            base_scale * factor
501        })
502        .collect();
503
504    let zero_points: Vec<i32> = (0..num_zp_trials)
505        .map(|i| {
506            let offset = -2 + i;
507            base_zero_point + offset
508        })
509        .collect();
510
511    let mut best_scale = base_scale;
512    let mut best_zero_point = base_zero_point;
513    let mut min_mse = f32::MAX;
514
515    // Test each combination of scale and zero point
516    for &_scale in &scales {
517        for &zero_point in &zero_points {
518            // Create temporary QuantizationParams
519            let mut params = QuantizationParams {
520                bits,
521                scale: _scale,
522                zero_point,
523                min_val: 0.0, // Will be set by quantize_matrix
524                max_val: 0.0, // Will be set by quantize_matrix
525                method: QuantizationMethod::Affine,
526                data_type: determine_data_type(bits),
527                channel_scales: None,
528                channel_zero_points: None,
529            };
530
531            // Manually simulate affine quantization and dequantization for F type
532            let matrix_f32 = matrix.mapv(|x| x.as_());
533            let scale = params.scale;
534            let zero_point = params.zero_point;
535
536            // Find min/max values for the matrix
537            let mut min_val = f32::MAX;
538            let mut max_val = f32::MIN;
539            for &val in matrix_f32.iter() {
540                if val.is_finite() {
541                    min_val = min_val.min(val);
542                    max_val = max_val.max(val);
543                }
544            }
545            params.min_val = min_val;
546            params.max_val = max_val;
547
548            let dequantized = matrix_f32.mapv(|x| {
549                let quantized = ((x / scale) + zero_point as f32)
550                    .round()
551                    .clamp(0.0, ((1 << bits) - 1) as f32);
552                (quantized - zero_point as f32) * scale
553            });
554
555            // Calculate MSE
556            let mse = (&matrix_f32 - &dequantized).mapv(|x| x * x).sum() / matrix.len() as f32;
557
558            if mse < min_mse {
559                min_mse = mse;
560                best_scale = scale;
561                best_zero_point = zero_point;
562            }
563        }
564    }
565
566    (best_scale, best_zero_point)
567}
568
569/// Optimize affine quantization parameters for vectors using MSE
570#[allow(dead_code)]
571pub fn optimize_affine_params_vec<F>(
572    vector: &ArrayView1<F>,
573    bits: u8,
574    base_scale: f32,
575    base_zero_point: i32,
576) -> (f32, i32)
577where
578    F: scirs2_core::numeric::Float
579        + Debug
580        + scirs2_core::numeric::AsPrimitive<f32>
581        + scirs2_core::numeric::FromPrimitive,
582    f32: scirs2_core::numeric::AsPrimitive<F>,
583{
584    let num_scale_trials = 10;
585    let num_zp_trials = 5;
586
587    let scales: Vec<f32> = (0..num_scale_trials)
588        .map(|i| {
589            let factor = 0.8 + 0.4 * (i as f32 / (num_scale_trials - 1) as f32);
590            base_scale * factor
591        })
592        .collect();
593
594    let zero_points: Vec<i32> = (0..num_zp_trials)
595        .map(|i| {
596            let offset = -2 + i;
597            base_zero_point + offset
598        })
599        .collect();
600
601    let mut best_scale = base_scale;
602    let mut best_zero_point = base_zero_point;
603    let mut min_mse = f32::MAX;
604
605    // Test each combination of scale and zero point
606    for &_scale in &scales {
607        for &zero_point in &zero_points {
608            // Create temporary QuantizationParams
609            let mut params = QuantizationParams {
610                bits,
611                scale: _scale,
612                zero_point,
613                min_val: 0.0, // Will be set by quantize_vector
614                max_val: 0.0, // Will be set by quantize_vector
615                method: QuantizationMethod::Affine,
616                data_type: determine_data_type(bits),
617                channel_scales: None,
618                channel_zero_points: None,
619            };
620
621            // Manually simulate affine quantization and dequantization for F type
622            let vector_f32 = vector.mapv(|x| x.as_());
623            let scale = params.scale;
624            let zero_point = params.zero_point;
625
626            // Find min/max values for the vector
627            let mut min_val = f32::MAX;
628            let mut max_val = f32::MIN;
629            for &val in vector_f32.iter() {
630                if val.is_finite() {
631                    min_val = min_val.min(val);
632                    max_val = max_val.max(val);
633                }
634            }
635            params.min_val = min_val;
636            params.max_val = max_val;
637
638            let dequantized = vector_f32.mapv(|x| {
639                let quantized = ((x / scale) + zero_point as f32)
640                    .round()
641                    .clamp(0.0, ((1 << bits) - 1) as f32);
642                (quantized - zero_point as f32) * scale
643            });
644
645            // Calculate MSE
646            let mse = (&vector_f32 - &dequantized).mapv(|x| x * x).sum() / vector.len() as f32;
647
648            if mse < min_mse {
649                min_mse = mse;
650                best_scale = scale;
651                best_zero_point = zero_point;
652            }
653        }
654    }
655
656    (best_scale, best_zero_point)
657}
658
659/// Create QuantizationParams from a min-max range
660#[allow(dead_code)]
661pub fn create_params_from_range(
662    bits: u8,
663    min_val: f32,
664    max_val: f32,
665    symmetric: bool,
666) -> LinalgResult<QuantizationParams> {
667    let (method, scale, zero_point) = if symmetric {
668        let abs_max = max_val.abs().max(min_val.abs());
669        let scale = abs_max / ((1 << (bits - 1)) - 1) as f32;
670        (QuantizationMethod::Symmetric, scale, 0)
671    } else {
672        let method = QuantizationMethod::Affine;
673        let scale = (max_val - min_val) / ((1 << bits) - 1) as f32;
674        let zero_point = (-min_val / scale).round() as i32;
675        (method, scale, zero_point)
676    };
677
678    Ok(QuantizationParams {
679        bits,
680        scale,
681        zero_point,
682        min_val,
683        max_val,
684        method,
685        data_type: determine_data_type(bits),
686        channel_scales: None,
687        channel_zero_points: None,
688    })
689}
690
691/// Determine the appropriate data type based on bit width
692#[allow(dead_code)]
693pub fn determine_data_type(bits: u8) -> super::super::QuantizedDataType {
694    use super::super::QuantizedDataType;
695
696    match bits {
697        4 => QuantizedDataType::Int4,     // Default to Int4 for 4-bit
698        8 => QuantizedDataType::Int8,     // Default to Int8 for 8-bit
699        16 => QuantizedDataType::Float16, // Default to Float16 for 16-bit
700        _ => QuantizedDataType::Int8,     // Default to Int8 for other cases
701    }
702}
703
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    /// The extremes of a 3x3 matrix are reported as-is.
    #[test]
    fn test_find_min_max() {
        let input = array![[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];

        let (lowest, highest) = find_min_max(&input.view());

        assert_eq!(lowest, 1.0);
        assert_eq!(highest, 9.0);
    }

    /// The extremes of a vector are reported as-is.
    #[test]
    fn test_find_min_max_vec() {
        let input = array![1.0f32, 2.0, 3.0, 4.0, 5.0];

        let (lowest, highest) = find_min_max_vec(&input.view());

        assert_eq!(lowest, 1.0);
        assert_eq!(highest, 5.0);
    }

    /// Six evenly spaced values spread over five bins.
    #[test]
    fn test_create_histogram() {
        let input = array![[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0]];

        let counts = create_histogram(&input.view(), 1.0, 6.0, 5);

        // Every value is counted exactly once
        let total: usize = counts.iter().sum();
        assert_eq!(total, 6);

        // Six values in five bins: no bin may hold more than two of them
        assert!(counts.iter().all(|&count| count <= 2));
    }

    /// Scale and zero point derived from a range, for both methods.
    #[test]
    fn test_create_params_from_range() {
        // Symmetric quantization
        let symmetric = create_params_from_range(8, -5.0, 5.0, true).unwrap();
        assert_eq!(symmetric.method, QuantizationMethod::Symmetric);
        assert_eq!(symmetric.zero_point, 0);
        assert_relative_eq!(symmetric.scale, 5.0 / 127.0, epsilon = 1e-6);

        // Affine quantization
        let affine = create_params_from_range(8, 1.0, 9.0, false).unwrap();
        assert_eq!(affine.method, QuantizationMethod::Affine);
        assert_relative_eq!(affine.scale, 8.0 / 255.0, epsilon = 1e-6);
        assert_eq!(affine.zero_point, (-1.0 / affine.scale).round() as i32);
    }
}