Skip to main content

torsh_quantization/
observers.rs

1//! Observer implementations for quantization parameter calibration
2//!
3//! This module provides various observer types for collecting statistics from tensors
4//! during the calibration phase of quantization. Observers track tensor distributions
5//! and calculate optimal quantization parameters.
6//!
7//! # Features
8//!
9//! - **MinMax Observer**: Simple min/max range tracking
10//! - **MovingAverage Observer**: Exponential moving average of ranges
11//! - **Histogram Observer**: Distribution-based quantization with outlier removal
12//! - **Percentile Observer**: Percentile-based range estimation
13//! - **Parallel Processing**: Optimized for large tensors using Rayon
14//! - **Outlier Detection**: IQR-based outlier detection and removal
15//! - **Memory Management**: Efficient memory usage for large datasets
16
17use crate::config::ObserverType;
18
19#[cfg(feature = "std")]
20use std::collections::HashMap;
21
22#[cfg(not(feature = "std"))]
23extern crate alloc;
24
25#[cfg(not(feature = "std"))]
26use alloc::{collections::BTreeMap as HashMap, string::String, vec::Vec};
27
28use torsh_core::{
29    dtype::DType,
30    error::{Result as TorshResult, TorshError},
31};
32use torsh_tensor::Tensor;
33
34/// Observer for tracking tensor statistics during calibration
35#[derive(Debug)]
36pub struct Observer {
37    observer_type: ObserverType,
38    min_val: f32,
39    max_val: f32,
40    num_batches: usize,
41    // For moving average observer
42    #[allow(dead_code)]
43    avg_min: f32,
44    #[allow(dead_code)]
45    avg_max: f32,
46    // For histogram observer
47    histogram: Vec<usize>,
48    hist_min: f32,
49    hist_max: f32,
50    num_bins: usize,
51    // For percentile observer
52    values: Vec<f32>,
53    percentile: f32,
54}
55
56impl Observer {
57    /// Create a new observer
58    pub fn new(observer_type: ObserverType) -> Self {
59        Self {
60            observer_type,
61            min_val: f32::INFINITY,
62            max_val: f32::NEG_INFINITY,
63            num_batches: 0,
64            avg_min: 0.0,
65            avg_max: 0.0,
66            histogram: vec![0; 256], // Default 256 bins
67            hist_min: f32::INFINITY,
68            hist_max: f32::NEG_INFINITY,
69            num_bins: 256,
70            values: Vec::new(),
71            percentile: 99.99, // Default percentile
72        }
73    }
74
75    /// Create a new histogram observer with specified number of bins
76    pub fn new_histogram(num_bins: usize) -> Self {
77        Self {
78            observer_type: ObserverType::Histogram,
79            min_val: f32::INFINITY,
80            max_val: f32::NEG_INFINITY,
81            num_batches: 0,
82            avg_min: 0.0,
83            avg_max: 0.0,
84            histogram: vec![0; num_bins],
85            hist_min: f32::INFINITY,
86            hist_max: f32::NEG_INFINITY,
87            num_bins,
88            values: Vec::new(),
89            percentile: 99.99,
90        }
91    }
92
93    /// Create a new percentile observer with specified percentile
94    pub fn new_percentile(percentile: f32) -> Self {
95        Self {
96            observer_type: ObserverType::Percentile,
97            min_val: f32::INFINITY,
98            max_val: f32::NEG_INFINITY,
99            num_batches: 0,
100            avg_min: 0.0,
101            avg_max: 0.0,
102            histogram: Vec::new(),
103            hist_min: f32::INFINITY,
104            hist_max: f32::NEG_INFINITY,
105            num_bins: 0,
106            values: Vec::new(),
107            percentile,
108        }
109    }
110
111    /// Update observer with new tensor (optimized with parallel processing)
112    pub fn update(&mut self, tensor: &Tensor) -> TorshResult<()> {
113        let data = tensor.data()?;
114
115        // Always count as a batch, even if data is empty
116        self.num_batches += 1;
117
118        if data.is_empty() {
119            return Ok(());
120        }
121
122        // Validate data for NaN/infinity
123        if data.iter().any(|&x| !x.is_finite()) {
124            return Err(TorshError::InvalidArgument(
125                "Tensor contains non-finite values (NaN or infinity)".to_string(),
126            ));
127        }
128
129        // Use parallel processing for large tensors
130        let (batch_min, batch_max) = if data.len() > 10000 {
131            #[cfg(feature = "std")]
132            {
133                use scirs2_core::parallel_ops::*;
134                data.par_iter().map(|&x| (x, x)).reduce(
135                    || (f32::INFINITY, f32::NEG_INFINITY),
136                    |(min1, max1), (min2, max2)| (min1.min(min2), max1.max(max2)),
137                )
138            }
139            #[cfg(not(feature = "std"))]
140            {
141                data.iter()
142                    .fold((f32::INFINITY, f32::NEG_INFINITY), |(min, max), &val| {
143                        (min.min(val), max.max(val))
144                    })
145            }
146        } else {
147            data.iter()
148                .fold((f32::INFINITY, f32::NEG_INFINITY), |(min, max), &val| {
149                    (min.min(val), max.max(val))
150                })
151        };
152
153        match self.observer_type {
154            ObserverType::MinMax => {
155                self.min_val = self.min_val.min(batch_min);
156                self.max_val = self.max_val.max(batch_max);
157            }
158            ObserverType::MovingAverage => {
159                if self.num_batches == 0 {
160                    self.min_val = batch_min;
161                    self.max_val = batch_max;
162                    self.avg_min = batch_min;
163                    self.avg_max = batch_max;
164                } else {
165                    let alpha = 0.01; // Moving average factor
166                    self.avg_min = alpha * batch_min + (1.0 - alpha) * self.avg_min;
167                    self.avg_max = alpha * batch_max + (1.0 - alpha) * self.avg_max;
168                    // Keep global min/max for reference
169                    self.min_val = self.min_val.min(batch_min);
170                    self.max_val = self.max_val.max(batch_max);
171                }
172            }
173            ObserverType::Histogram => {
174                // Update global min/max first
175                self.min_val = self.min_val.min(batch_min);
176                self.max_val = self.max_val.max(batch_max);
177
178                // Update histogram range if this is the first batch
179                if self.num_batches == 0 {
180                    self.hist_min = batch_min;
181                    self.hist_max = batch_max;
182                } else {
183                    self.hist_min = self.hist_min.min(batch_min);
184                    self.hist_max = self.hist_max.max(batch_max);
185                }
186
187                // Add values to histogram with improved binning
188                if data.len() > 5000 {
189                    // Use parallel histogram update for large tensors
190                    #[cfg(feature = "std")]
191                    {
192                        use scirs2_core::parallel_ops::*;
193                        let local_histograms: Vec<Vec<usize>> = data
194                            .par_chunks(1000)
195                            .map(|chunk| {
196                                let mut local_hist = vec![0; self.num_bins];
197                                for &value in chunk {
198                                    let bin_idx = self.value_to_bin_index(value);
199                                    if bin_idx < local_hist.len() {
200                                        local_hist[bin_idx] += 1;
201                                    }
202                                }
203                                local_hist
204                            })
205                            .collect();
206
207                        // Merge local histograms
208                        for local_hist in local_histograms {
209                            for (i, count) in local_hist.iter().enumerate() {
210                                self.histogram[i] += count;
211                            }
212                        }
213                    }
214                    #[cfg(not(feature = "std"))]
215                    {
216                        for &value in data.iter() {
217                            let bin_idx = self.value_to_bin_index(value);
218                            if bin_idx < self.histogram.len() {
219                                self.histogram[bin_idx] += 1;
220                            }
221                        }
222                    }
223                } else {
224                    for &value in data.iter() {
225                        let bin_idx = self.value_to_bin_index(value);
226                        if bin_idx < self.histogram.len() {
227                            self.histogram[bin_idx] += 1;
228                        }
229                    }
230                }
231            }
232            ObserverType::Percentile => {
233                // Update global min/max
234                self.min_val = self.min_val.min(batch_min);
235                self.max_val = self.max_val.max(batch_max);
236
237                // Limit memory usage for percentile calculation
238                if self.values.len() + data.len() > 100_000 {
239                    // Sample the data to avoid memory explosion
240                    let sample_rate = 100_000.0 / (self.values.len() + data.len()) as f32;
241                    let sampled_data: Vec<f32> = data
242                        .iter()
243                        .enumerate()
244                        .filter(|(i, _)| (*i as f32 * sample_rate) % 1.0 < sample_rate)
245                        .map(|(_, &val)| val)
246                        .collect();
247                    self.values.extend(sampled_data);
248                } else {
249                    self.values.extend(data.iter().cloned());
250                }
251            }
252            _ => {
253                // For other observer types, fall back to min-max
254                self.min_val = self.min_val.min(batch_min);
255                self.max_val = self.max_val.max(batch_max);
256            }
257        }
258
259        Ok(())
260    }
261
262    /// Calculate quantization parameters from observed statistics
263    pub fn calculate_qparams(&self, dtype: DType) -> TorshResult<(f32, i32)> {
264        let (qmin, qmax) = match dtype {
265            DType::I8 => (-128, 127),
266            DType::U8 => (0, 255),
267            _ => {
268                return Err(TorshError::InvalidArgument(
269                    "Unsupported quantization dtype".to_string(),
270                ))
271            }
272        };
273
274        // Use observer-specific range calculation
275        let (min_val, max_val) = match self.observer_type {
276            ObserverType::Histogram => {
277                if !self.histogram.is_empty() {
278                    self.calculate_histogram_range()
279                } else {
280                    (self.min_val.min(0.0), self.max_val.max(0.0))
281                }
282            }
283            ObserverType::Percentile => {
284                if !self.values.is_empty() {
285                    self.calculate_percentile_range()
286                } else {
287                    (self.min_val.min(0.0), self.max_val.max(0.0))
288                }
289            }
290            _ => (self.min_val.min(0.0), self.max_val.max(0.0)),
291        };
292
293        let scale = (max_val - min_val) / (qmax - qmin) as f32;
294        let scale = if scale == 0.0 { 1.0 } else { scale };
295
296        let zero_point = (qmin as f32 - min_val / scale)
297            .round()
298            .max(qmin as f32)
299            .min(qmax as f32) as i32;
300
301        Ok((scale, zero_point))
302    }
303
304    /// Convert a value to histogram bin index with improved stability
305    fn value_to_bin_index(&self, value: f32) -> usize {
306        // Use hist_min/hist_max for more accurate binning
307        let range_min = if self.hist_min.is_finite() {
308            self.hist_min
309        } else {
310            self.min_val
311        };
312        let range_max = if self.hist_max.is_finite() {
313            self.hist_max
314        } else {
315            self.max_val
316        };
317
318        if range_max <= range_min || !value.is_finite() {
319            return 0;
320        }
321
322        let ratio = ((value - range_min) / (range_max - range_min)).clamp(0.0, 1.0);
323        let idx = (ratio * self.num_bins as f32).floor() as usize;
324        idx.min(self.num_bins - 1)
325    }
326
327    /// Calculate optimal range from histogram with enhanced outlier removal
328    fn calculate_histogram_range(&self) -> (f32, f32) {
329        if self.histogram.is_empty() || self.num_bins == 0 {
330            return (self.min_val, self.max_val);
331        }
332
333        let total_samples: usize = self.histogram.iter().sum();
334        if total_samples == 0 {
335            return (self.min_val, self.max_val);
336        }
337
338        // Use adaptive threshold based on data distribution
339        let outlier_threshold = if total_samples > 10000 {
340            0.001 // 0.1% for large datasets
341        } else if total_samples > 1000 {
342            0.005 // 0.5% for medium datasets
343        } else {
344            0.01 // 1% for small datasets
345        };
346
347        let threshold_count = (total_samples as f32 * outlier_threshold) as usize;
348        let mut cumsum = 0;
349        let mut start_bin = 0;
350        let mut end_bin = self.num_bins - 1;
351
352        // Find start bin (skip outliers from the beginning)
353        for (i, &count) in self.histogram.iter().enumerate() {
354            cumsum += count;
355            if cumsum > threshold_count {
356                start_bin = i;
357                break;
358            }
359        }
360
361        // Find end bin (skip outliers from the end)
362        cumsum = 0;
363        for (i, &count) in self.histogram.iter().enumerate().rev() {
364            cumsum += count;
365            if cumsum > threshold_count {
366                end_bin = i;
367                break;
368            }
369        }
370
371        // Ensure we have a valid range
372        if start_bin >= end_bin {
373            return (self.min_val, self.max_val);
374        }
375
376        let range_min = if self.hist_min.is_finite() {
377            self.hist_min
378        } else {
379            self.min_val
380        };
381        let range_max = if self.hist_max.is_finite() {
382            self.hist_max
383        } else {
384            self.max_val
385        };
386
387        if range_max <= range_min {
388            return (self.min_val, self.max_val);
389        }
390
391        let bin_width = (range_max - range_min) / self.num_bins as f32;
392        let min_val = range_min + start_bin as f32 * bin_width;
393        let max_val = range_min + (end_bin + 1) as f32 * bin_width;
394
395        // Ensure the calculated range is valid
396        if min_val >= max_val {
397            (self.min_val, self.max_val)
398        } else {
399            (min_val.max(self.min_val), max_val.min(self.max_val))
400        }
401    }
402
403    /// Calculate percentile-based range
404    fn calculate_percentile_range(&self) -> (f32, f32) {
405        if self.values.is_empty() {
406            return (self.min_val, self.max_val);
407        }
408
409        let mut sorted_values = self.values.clone();
410        sorted_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal));
411
412        let n = sorted_values.len();
413        let lower_percentile = 100.0 - self.percentile;
414        let upper_percentile = self.percentile;
415
416        let lower_idx = ((lower_percentile / 100.0) * n as f32) as usize;
417        let upper_idx = ((upper_percentile / 100.0) * n as f32) as usize;
418
419        let lower_idx = lower_idx.min(n - 1);
420        let upper_idx = upper_idx.min(n - 1);
421
422        (sorted_values[lower_idx], sorted_values[upper_idx])
423    }
424
425    /// Detect and remove outliers using IQR method
426    pub fn detect_outliers(&self, data: &[f32], factor: f32) -> (Vec<f32>, usize) {
427        if data.is_empty() {
428            return (Vec::new(), 0);
429        }
430
431        let mut sorted_data = data.to_vec();
432        sorted_data.sort_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal));
433
434        let n = sorted_data.len();
435
436        // Use proper percentile calculation for quartiles
437        let q1 = if n >= 4 {
438            let idx = (n as f32 * 0.25) as usize;
439            if idx > 0 {
440                sorted_data[idx.min(n - 1)]
441            } else {
442                sorted_data[0]
443            }
444        } else {
445            sorted_data[0]
446        };
447
448        let q3 = if n >= 4 {
449            let idx = (n as f32 * 0.75) as usize;
450            sorted_data[idx.min(n - 1)]
451        } else {
452            sorted_data[n - 1]
453        };
454
455        let iqr = q3 - q1;
456
457        // If IQR is too small, use a more conservative approach
458        if iqr < 1e-6 {
459            return (sorted_data, 0);
460        }
461
462        let lower_bound = q1 - factor * iqr;
463        let upper_bound = q3 + factor * iqr;
464
465        let original_len = data.len();
466        let cleaned_data: Vec<f32> = data
467            .iter()
468            .filter(|&&x| x >= lower_bound && x <= upper_bound)
469            .cloned()
470            .collect();
471
472        let outliers_removed = original_len - cleaned_data.len();
473
474        (cleaned_data, outliers_removed)
475    }
476
477    /// Get comprehensive statistics from the observer
478    pub fn get_statistics(&self) -> HashMap<String, f32> {
479        let mut stats = HashMap::new();
480
481        stats.insert("min_val".to_string(), self.min_val);
482        stats.insert("max_val".to_string(), self.max_val);
483        stats.insert("range".to_string(), self.max_val - self.min_val);
484        stats.insert("num_batches".to_string(), self.num_batches as f32);
485
486        match self.observer_type {
487            ObserverType::Histogram => {
488                stats.insert("num_bins".to_string(), self.num_bins as f32);
489                stats.insert(
490                    "total_samples".to_string(),
491                    self.histogram.iter().sum::<usize>() as f32,
492                );
493                if !self.histogram.is_empty() {
494                    let max_bin_count = *self.histogram.iter().max().unwrap_or(&0);
495                    stats.insert("max_bin_count".to_string(), max_bin_count as f32);
496                }
497            }
498            ObserverType::Percentile => {
499                stats.insert("total_values".to_string(), self.values.len() as f32);
500                stats.insert("percentile".to_string(), self.percentile);
501            }
502            _ => {}
503        }
504
505        stats
506    }
507
508    /// Get the observer type
509    pub fn observer_type(&self) -> ObserverType {
510        self.observer_type
511    }
512
513    /// Get the current min/max values
514    pub fn get_min_max(&self) -> (f32, f32) {
515        (self.min_val, self.max_val)
516    }
517
518    /// Get number of processed batches
519    pub fn num_batches(&self) -> usize {
520        self.num_batches
521    }
522
523    /// Reset the observer state
524    pub fn reset(&mut self) {
525        self.min_val = f32::INFINITY;
526        self.max_val = f32::NEG_INFINITY;
527        self.num_batches = 0;
528        self.avg_min = 0.0;
529        self.avg_max = 0.0;
530        self.hist_min = f32::INFINITY;
531        self.hist_max = f32::NEG_INFINITY;
532        self.histogram.iter_mut().for_each(|x| *x = 0);
533        self.values.clear();
534    }
535}
536
537/// Factory functions for creating observers
538impl Observer {
539    /// Create a MinMax observer
540    pub fn min_max() -> Self {
541        Self::new(ObserverType::MinMax)
542    }
543
544    /// Create a MovingAverage observer
545    pub fn moving_average() -> Self {
546        Self::new(ObserverType::MovingAverage)
547    }
548
549    /// Create a Histogram observer with default bins
550    pub fn histogram() -> Self {
551        Self::new(ObserverType::Histogram)
552    }
553
554    /// Create a Histogram observer with custom number of bins
555    pub fn histogram_with_bins(num_bins: usize) -> Self {
556        Self::new_histogram(num_bins)
557    }
558
559    /// Create a Percentile observer with default percentile
560    pub fn percentile() -> Self {
561        Self::new(ObserverType::Percentile)
562    }
563
564    /// Create a Percentile observer with custom percentile
565    pub fn percentile_with_value(percentile: f32) -> Self {
566        Self::new_percentile(percentile)
567    }
568}
569
#[cfg(test)]
mod tests {
    use super::*;

    use torsh_tensor::creation::tensor_1d;

    /// Wrap `values` in a 1-D tensor and feed it to `observer`.
    fn feed(observer: &mut Observer, values: &[f32]) -> TorshResult<()> {
        let tensor = tensor_1d(values).unwrap();
        observer.update(&tensor)
    }

    #[test]
    fn test_observer_creation() {
        assert_eq!(Observer::min_max().observer_type(), ObserverType::MinMax);

        let with_bins = Observer::histogram_with_bins(128);
        assert_eq!(with_bins.observer_type(), ObserverType::Histogram);
        assert_eq!(with_bins.num_bins, 128);

        let with_percentile = Observer::percentile_with_value(95.0);
        assert_eq!(with_percentile.observer_type(), ObserverType::Percentile);
        assert_eq!(with_percentile.percentile, 95.0);
    }

    #[test]
    fn test_minmax_observer() {
        let mut observer = Observer::min_max();

        feed(&mut observer, &[1.0, 2.0, 3.0, 4.0]).unwrap();
        assert_eq!(observer.get_min_max(), (1.0, 4.0));

        // A second batch widens the range on both sides.
        feed(&mut observer, &[0.5, 5.0]).unwrap();
        assert_eq!(observer.get_min_max(), (0.5, 5.0));
    }

    #[test]
    fn test_histogram_observer() {
        let mut observer = Observer::histogram_with_bins(10);
        feed(&mut observer, &[1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();

        let stats = observer.get_statistics();
        assert_eq!(stats.get("total_samples"), Some(&5.0));
        assert_eq!(stats.get("num_bins"), Some(&10.0));
    }

    #[test]
    fn test_percentile_observer() {
        let mut observer = Observer::percentile_with_value(90.0);

        let ramp: Vec<f32> = (0..100).map(|i| i as f32).collect();
        feed(&mut observer, &ramp).unwrap();

        let stats = observer.get_statistics();
        assert_eq!(stats.get("total_values"), Some(&100.0));
        assert_eq!(stats.get("percentile"), Some(&90.0));
    }

    #[test]
    fn test_calculate_qparams() {
        let mut observer = Observer::min_max();
        feed(&mut observer, &[-2.0, -1.0, 0.0, 1.0, 2.0]).unwrap();

        let (scale, zero_point) = observer.calculate_qparams(DType::I8).unwrap();
        assert!(scale > 0.0);
        assert!((-128..=127).contains(&zero_point));
    }

    #[test]
    fn test_outlier_detection() {
        let observer = Observer::min_max();
        // 100.0 is an outlier
        let samples = [1.0, 2.0, 3.0, 4.0, 5.0, 100.0];

        let (kept, removed) = observer.detect_outliers(&samples, 1.5);
        assert!(removed > 0);
        assert!(kept.len() < samples.len());
        assert!(!kept.contains(&100.0));
    }

    #[test]
    fn test_observer_reset() {
        let mut observer = Observer::min_max();
        feed(&mut observer, &[1.0, 2.0, 3.0]).unwrap();
        assert_eq!(observer.num_batches(), 1);

        observer.reset();
        assert_eq!(observer.num_batches(), 0);

        // The range trackers are back at their sentinels.
        let (min, max) = observer.get_min_max();
        assert_eq!(min, f32::INFINITY);
        assert_eq!(max, f32::NEG_INFINITY);
    }

    #[test]
    fn test_invalid_tensor_data() {
        let mut observer = Observer::min_max();
        let outcome = feed(&mut observer, &[f32::NAN, 1.0, 2.0]);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_empty_tensor() {
        let mut observer = Observer::min_max();
        let nothing: Vec<f32> = Vec::new();

        let outcome = feed(&mut observer, &nothing);
        assert!(outcome.is_ok());
        // Empty batches are still counted.
        assert_eq!(observer.num_batches(), 1);
    }

    #[test]
    fn test_unsupported_dtype() {
        let mut observer = Observer::min_max();
        feed(&mut observer, &[1.0, 2.0, 3.0]).unwrap();

        assert!(observer.calculate_qparams(DType::F32).is_err());
    }
}