1use crate::config::ObserverType;
18
19#[cfg(feature = "std")]
20use std::collections::HashMap;
21
22#[cfg(not(feature = "std"))]
23extern crate alloc;
24
25#[cfg(not(feature = "std"))]
26use alloc::{collections::BTreeMap as HashMap, string::String, vec::Vec};
27
28use torsh_core::{
29 dtype::DType,
30 error::{Result as TorshResult, TorshError},
31};
32use torsh_tensor::Tensor;
33
34#[derive(Debug)]
36pub struct Observer {
37 observer_type: ObserverType,
38 min_val: f32,
39 max_val: f32,
40 num_batches: usize,
41 #[allow(dead_code)]
43 avg_min: f32,
44 #[allow(dead_code)]
45 avg_max: f32,
46 histogram: Vec<usize>,
48 hist_min: f32,
49 hist_max: f32,
50 num_bins: usize,
51 values: Vec<f32>,
53 percentile: f32,
54}
55
56impl Observer {
57 pub fn new(observer_type: ObserverType) -> Self {
59 Self {
60 observer_type,
61 min_val: f32::INFINITY,
62 max_val: f32::NEG_INFINITY,
63 num_batches: 0,
64 avg_min: 0.0,
65 avg_max: 0.0,
66 histogram: vec![0; 256], hist_min: f32::INFINITY,
68 hist_max: f32::NEG_INFINITY,
69 num_bins: 256,
70 values: Vec::new(),
71 percentile: 99.99, }
73 }
74
75 pub fn new_histogram(num_bins: usize) -> Self {
77 Self {
78 observer_type: ObserverType::Histogram,
79 min_val: f32::INFINITY,
80 max_val: f32::NEG_INFINITY,
81 num_batches: 0,
82 avg_min: 0.0,
83 avg_max: 0.0,
84 histogram: vec![0; num_bins],
85 hist_min: f32::INFINITY,
86 hist_max: f32::NEG_INFINITY,
87 num_bins,
88 values: Vec::new(),
89 percentile: 99.99,
90 }
91 }
92
93 pub fn new_percentile(percentile: f32) -> Self {
95 Self {
96 observer_type: ObserverType::Percentile,
97 min_val: f32::INFINITY,
98 max_val: f32::NEG_INFINITY,
99 num_batches: 0,
100 avg_min: 0.0,
101 avg_max: 0.0,
102 histogram: Vec::new(),
103 hist_min: f32::INFINITY,
104 hist_max: f32::NEG_INFINITY,
105 num_bins: 0,
106 values: Vec::new(),
107 percentile,
108 }
109 }
110
111 pub fn update(&mut self, tensor: &Tensor) -> TorshResult<()> {
113 let data = tensor.data()?;
114
115 self.num_batches += 1;
117
118 if data.is_empty() {
119 return Ok(());
120 }
121
122 if data.iter().any(|&x| !x.is_finite()) {
124 return Err(TorshError::InvalidArgument(
125 "Tensor contains non-finite values (NaN or infinity)".to_string(),
126 ));
127 }
128
129 let (batch_min, batch_max) = if data.len() > 10000 {
131 #[cfg(feature = "std")]
132 {
133 use scirs2_core::parallel_ops::*;
134 data.par_iter().map(|&x| (x, x)).reduce(
135 || (f32::INFINITY, f32::NEG_INFINITY),
136 |(min1, max1), (min2, max2)| (min1.min(min2), max1.max(max2)),
137 )
138 }
139 #[cfg(not(feature = "std"))]
140 {
141 data.iter()
142 .fold((f32::INFINITY, f32::NEG_INFINITY), |(min, max), &val| {
143 (min.min(val), max.max(val))
144 })
145 }
146 } else {
147 data.iter()
148 .fold((f32::INFINITY, f32::NEG_INFINITY), |(min, max), &val| {
149 (min.min(val), max.max(val))
150 })
151 };
152
153 match self.observer_type {
154 ObserverType::MinMax => {
155 self.min_val = self.min_val.min(batch_min);
156 self.max_val = self.max_val.max(batch_max);
157 }
158 ObserverType::MovingAverage => {
159 if self.num_batches == 0 {
160 self.min_val = batch_min;
161 self.max_val = batch_max;
162 self.avg_min = batch_min;
163 self.avg_max = batch_max;
164 } else {
165 let alpha = 0.01; self.avg_min = alpha * batch_min + (1.0 - alpha) * self.avg_min;
167 self.avg_max = alpha * batch_max + (1.0 - alpha) * self.avg_max;
168 self.min_val = self.min_val.min(batch_min);
170 self.max_val = self.max_val.max(batch_max);
171 }
172 }
173 ObserverType::Histogram => {
174 self.min_val = self.min_val.min(batch_min);
176 self.max_val = self.max_val.max(batch_max);
177
178 if self.num_batches == 0 {
180 self.hist_min = batch_min;
181 self.hist_max = batch_max;
182 } else {
183 self.hist_min = self.hist_min.min(batch_min);
184 self.hist_max = self.hist_max.max(batch_max);
185 }
186
187 if data.len() > 5000 {
189 #[cfg(feature = "std")]
191 {
192 use scirs2_core::parallel_ops::*;
193 let local_histograms: Vec<Vec<usize>> = data
194 .par_chunks(1000)
195 .map(|chunk| {
196 let mut local_hist = vec![0; self.num_bins];
197 for &value in chunk {
198 let bin_idx = self.value_to_bin_index(value);
199 if bin_idx < local_hist.len() {
200 local_hist[bin_idx] += 1;
201 }
202 }
203 local_hist
204 })
205 .collect();
206
207 for local_hist in local_histograms {
209 for (i, count) in local_hist.iter().enumerate() {
210 self.histogram[i] += count;
211 }
212 }
213 }
214 #[cfg(not(feature = "std"))]
215 {
216 for &value in data.iter() {
217 let bin_idx = self.value_to_bin_index(value);
218 if bin_idx < self.histogram.len() {
219 self.histogram[bin_idx] += 1;
220 }
221 }
222 }
223 } else {
224 for &value in data.iter() {
225 let bin_idx = self.value_to_bin_index(value);
226 if bin_idx < self.histogram.len() {
227 self.histogram[bin_idx] += 1;
228 }
229 }
230 }
231 }
232 ObserverType::Percentile => {
233 self.min_val = self.min_val.min(batch_min);
235 self.max_val = self.max_val.max(batch_max);
236
237 if self.values.len() + data.len() > 100_000 {
239 let sample_rate = 100_000.0 / (self.values.len() + data.len()) as f32;
241 let sampled_data: Vec<f32> = data
242 .iter()
243 .enumerate()
244 .filter(|(i, _)| (*i as f32 * sample_rate) % 1.0 < sample_rate)
245 .map(|(_, &val)| val)
246 .collect();
247 self.values.extend(sampled_data);
248 } else {
249 self.values.extend(data.iter().cloned());
250 }
251 }
252 _ => {
253 self.min_val = self.min_val.min(batch_min);
255 self.max_val = self.max_val.max(batch_max);
256 }
257 }
258
259 Ok(())
260 }
261
262 pub fn calculate_qparams(&self, dtype: DType) -> TorshResult<(f32, i32)> {
264 let (qmin, qmax) = match dtype {
265 DType::I8 => (-128, 127),
266 DType::U8 => (0, 255),
267 _ => {
268 return Err(TorshError::InvalidArgument(
269 "Unsupported quantization dtype".to_string(),
270 ))
271 }
272 };
273
274 let (min_val, max_val) = match self.observer_type {
276 ObserverType::Histogram => {
277 if !self.histogram.is_empty() {
278 self.calculate_histogram_range()
279 } else {
280 (self.min_val.min(0.0), self.max_val.max(0.0))
281 }
282 }
283 ObserverType::Percentile => {
284 if !self.values.is_empty() {
285 self.calculate_percentile_range()
286 } else {
287 (self.min_val.min(0.0), self.max_val.max(0.0))
288 }
289 }
290 _ => (self.min_val.min(0.0), self.max_val.max(0.0)),
291 };
292
293 let scale = (max_val - min_val) / (qmax - qmin) as f32;
294 let scale = if scale == 0.0 { 1.0 } else { scale };
295
296 let zero_point = (qmin as f32 - min_val / scale)
297 .round()
298 .max(qmin as f32)
299 .min(qmax as f32) as i32;
300
301 Ok((scale, zero_point))
302 }
303
304 fn value_to_bin_index(&self, value: f32) -> usize {
306 let range_min = if self.hist_min.is_finite() {
308 self.hist_min
309 } else {
310 self.min_val
311 };
312 let range_max = if self.hist_max.is_finite() {
313 self.hist_max
314 } else {
315 self.max_val
316 };
317
318 if range_max <= range_min || !value.is_finite() {
319 return 0;
320 }
321
322 let ratio = ((value - range_min) / (range_max - range_min)).clamp(0.0, 1.0);
323 let idx = (ratio * self.num_bins as f32).floor() as usize;
324 idx.min(self.num_bins - 1)
325 }
326
327 fn calculate_histogram_range(&self) -> (f32, f32) {
329 if self.histogram.is_empty() || self.num_bins == 0 {
330 return (self.min_val, self.max_val);
331 }
332
333 let total_samples: usize = self.histogram.iter().sum();
334 if total_samples == 0 {
335 return (self.min_val, self.max_val);
336 }
337
338 let outlier_threshold = if total_samples > 10000 {
340 0.001 } else if total_samples > 1000 {
342 0.005 } else {
344 0.01 };
346
347 let threshold_count = (total_samples as f32 * outlier_threshold) as usize;
348 let mut cumsum = 0;
349 let mut start_bin = 0;
350 let mut end_bin = self.num_bins - 1;
351
352 for (i, &count) in self.histogram.iter().enumerate() {
354 cumsum += count;
355 if cumsum > threshold_count {
356 start_bin = i;
357 break;
358 }
359 }
360
361 cumsum = 0;
363 for (i, &count) in self.histogram.iter().enumerate().rev() {
364 cumsum += count;
365 if cumsum > threshold_count {
366 end_bin = i;
367 break;
368 }
369 }
370
371 if start_bin >= end_bin {
373 return (self.min_val, self.max_val);
374 }
375
376 let range_min = if self.hist_min.is_finite() {
377 self.hist_min
378 } else {
379 self.min_val
380 };
381 let range_max = if self.hist_max.is_finite() {
382 self.hist_max
383 } else {
384 self.max_val
385 };
386
387 if range_max <= range_min {
388 return (self.min_val, self.max_val);
389 }
390
391 let bin_width = (range_max - range_min) / self.num_bins as f32;
392 let min_val = range_min + start_bin as f32 * bin_width;
393 let max_val = range_min + (end_bin + 1) as f32 * bin_width;
394
395 if min_val >= max_val {
397 (self.min_val, self.max_val)
398 } else {
399 (min_val.max(self.min_val), max_val.min(self.max_val))
400 }
401 }
402
403 fn calculate_percentile_range(&self) -> (f32, f32) {
405 if self.values.is_empty() {
406 return (self.min_val, self.max_val);
407 }
408
409 let mut sorted_values = self.values.clone();
410 sorted_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal));
411
412 let n = sorted_values.len();
413 let lower_percentile = 100.0 - self.percentile;
414 let upper_percentile = self.percentile;
415
416 let lower_idx = ((lower_percentile / 100.0) * n as f32) as usize;
417 let upper_idx = ((upper_percentile / 100.0) * n as f32) as usize;
418
419 let lower_idx = lower_idx.min(n - 1);
420 let upper_idx = upper_idx.min(n - 1);
421
422 (sorted_values[lower_idx], sorted_values[upper_idx])
423 }
424
425 pub fn detect_outliers(&self, data: &[f32], factor: f32) -> (Vec<f32>, usize) {
427 if data.is_empty() {
428 return (Vec::new(), 0);
429 }
430
431 let mut sorted_data = data.to_vec();
432 sorted_data.sort_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal));
433
434 let n = sorted_data.len();
435
436 let q1 = if n >= 4 {
438 let idx = (n as f32 * 0.25) as usize;
439 if idx > 0 {
440 sorted_data[idx.min(n - 1)]
441 } else {
442 sorted_data[0]
443 }
444 } else {
445 sorted_data[0]
446 };
447
448 let q3 = if n >= 4 {
449 let idx = (n as f32 * 0.75) as usize;
450 sorted_data[idx.min(n - 1)]
451 } else {
452 sorted_data[n - 1]
453 };
454
455 let iqr = q3 - q1;
456
457 if iqr < 1e-6 {
459 return (sorted_data, 0);
460 }
461
462 let lower_bound = q1 - factor * iqr;
463 let upper_bound = q3 + factor * iqr;
464
465 let original_len = data.len();
466 let cleaned_data: Vec<f32> = data
467 .iter()
468 .filter(|&&x| x >= lower_bound && x <= upper_bound)
469 .cloned()
470 .collect();
471
472 let outliers_removed = original_len - cleaned_data.len();
473
474 (cleaned_data, outliers_removed)
475 }
476
477 pub fn get_statistics(&self) -> HashMap<String, f32> {
479 let mut stats = HashMap::new();
480
481 stats.insert("min_val".to_string(), self.min_val);
482 stats.insert("max_val".to_string(), self.max_val);
483 stats.insert("range".to_string(), self.max_val - self.min_val);
484 stats.insert("num_batches".to_string(), self.num_batches as f32);
485
486 match self.observer_type {
487 ObserverType::Histogram => {
488 stats.insert("num_bins".to_string(), self.num_bins as f32);
489 stats.insert(
490 "total_samples".to_string(),
491 self.histogram.iter().sum::<usize>() as f32,
492 );
493 if !self.histogram.is_empty() {
494 let max_bin_count = *self.histogram.iter().max().unwrap_or(&0);
495 stats.insert("max_bin_count".to_string(), max_bin_count as f32);
496 }
497 }
498 ObserverType::Percentile => {
499 stats.insert("total_values".to_string(), self.values.len() as f32);
500 stats.insert("percentile".to_string(), self.percentile);
501 }
502 _ => {}
503 }
504
505 stats
506 }
507
508 pub fn observer_type(&self) -> ObserverType {
510 self.observer_type
511 }
512
513 pub fn get_min_max(&self) -> (f32, f32) {
515 (self.min_val, self.max_val)
516 }
517
518 pub fn num_batches(&self) -> usize {
520 self.num_batches
521 }
522
523 pub fn reset(&mut self) {
525 self.min_val = f32::INFINITY;
526 self.max_val = f32::NEG_INFINITY;
527 self.num_batches = 0;
528 self.avg_min = 0.0;
529 self.avg_max = 0.0;
530 self.hist_min = f32::INFINITY;
531 self.hist_max = f32::NEG_INFINITY;
532 self.histogram.iter_mut().for_each(|x| *x = 0);
533 self.values.clear();
534 }
535}
536
537impl Observer {
539 pub fn min_max() -> Self {
541 Self::new(ObserverType::MinMax)
542 }
543
544 pub fn moving_average() -> Self {
546 Self::new(ObserverType::MovingAverage)
547 }
548
549 pub fn histogram() -> Self {
551 Self::new(ObserverType::Histogram)
552 }
553
554 pub fn histogram_with_bins(num_bins: usize) -> Self {
556 Self::new_histogram(num_bins)
557 }
558
559 pub fn percentile() -> Self {
561 Self::new(ObserverType::Percentile)
562 }
563
564 pub fn percentile_with_value(percentile: f32) -> Self {
566 Self::new_percentile(percentile)
567 }
568}
569
#[cfg(test)]
mod tests {
    use super::*;

    use torsh_tensor::creation::tensor_1d;

    /// Each convenience constructor yields the matching observer type and
    /// stores its configuration value.
    #[test]
    fn test_observer_creation() {
        let minmax_observer = Observer::min_max();
        assert_eq!(minmax_observer.observer_type(), ObserverType::MinMax);

        let histogram_observer = Observer::histogram_with_bins(128);
        assert_eq!(histogram_observer.observer_type(), ObserverType::Histogram);
        assert_eq!(histogram_observer.num_bins, 128);

        let percentile_observer = Observer::percentile_with_value(95.0);
        assert_eq!(
            percentile_observer.observer_type(),
            ObserverType::Percentile
        );
        assert_eq!(percentile_observer.percentile, 95.0);
    }

    /// The min/max observer accumulates extrema across batches.
    #[test]
    fn test_minmax_observer() {
        let mut observer = Observer::min_max();

        let data1 = vec![1.0, 2.0, 3.0, 4.0];
        let tensor1 = tensor_1d(&data1).unwrap();
        observer.update(&tensor1).unwrap();

        let (min, max) = observer.get_min_max();
        assert_eq!(min, 1.0);
        assert_eq!(max, 4.0);

        // A second batch widens the range on both sides.
        let data2 = vec![0.5, 5.0];
        let tensor2 = tensor_1d(&data2).unwrap();
        observer.update(&tensor2).unwrap();

        let (min, max) = observer.get_min_max();
        assert_eq!(min, 0.5);
        assert_eq!(max, 5.0);
    }

    /// Every observed value lands in some histogram bin.
    #[test]
    fn test_histogram_observer() {
        let mut observer = Observer::histogram_with_bins(10);

        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let tensor = tensor_1d(&data).unwrap();
        observer.update(&tensor).unwrap();

        let stats = observer.get_statistics();
        assert_eq!(stats.get("total_samples"), Some(&5.0));
        assert_eq!(stats.get("num_bins"), Some(&10.0));
    }

    /// The percentile observer retains all values of a small batch.
    #[test]
    fn test_percentile_observer() {
        let mut observer = Observer::percentile_with_value(90.0);

        let data: Vec<f32> = (0..100).map(|i| i as f32).collect();
        let tensor = tensor_1d(&data).unwrap();
        observer.update(&tensor).unwrap();

        let stats = observer.get_statistics();
        assert_eq!(stats.get("total_values"), Some(&100.0));
        assert_eq!(stats.get("percentile"), Some(&90.0));
    }

    /// Quantization parameters for I8 stay within the I8 range.
    #[test]
    fn test_calculate_qparams() {
        let mut observer = Observer::min_max();

        let data = vec![-2.0, -1.0, 0.0, 1.0, 2.0];
        let tensor = tensor_1d(&data).unwrap();
        observer.update(&tensor).unwrap();

        let (scale, zero_point) = observer.calculate_qparams(DType::I8).unwrap();
        assert!(scale > 0.0);
        assert!((-128..=127).contains(&zero_point));
    }

    /// A single extreme value is removed by the IQR rule.
    #[test]
    fn test_outlier_detection() {
        let observer = Observer::min_max();
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 100.0];

        let (cleaned_data, outliers_removed) = observer.detect_outliers(&data, 1.5);
        assert!(outliers_removed > 0);
        assert!(cleaned_data.len() < data.len());
        assert!(!cleaned_data.contains(&100.0));
    }

    /// `reset` returns the observer to its initial, data-free state.
    #[test]
    fn test_observer_reset() {
        let mut observer = Observer::min_max();

        let data = vec![1.0, 2.0, 3.0];
        let tensor = tensor_1d(&data).unwrap();
        observer.update(&tensor).unwrap();

        assert_eq!(observer.num_batches(), 1);

        observer.reset();
        assert_eq!(observer.num_batches(), 0);

        // min is +inf and max is -inf again.
        let (min, max) = observer.get_min_max();
        assert!(min.is_infinite() && min > 0.0);
        assert!(max.is_infinite() && max < 0.0);
    }

    /// Tensors containing NaN are rejected.
    #[test]
    fn test_invalid_tensor_data() {
        let mut observer = Observer::min_max();

        let data = vec![f32::NAN, 1.0, 2.0];
        let tensor = tensor_1d(&data).unwrap();

        let result = observer.update(&tensor);
        assert!(result.is_err());
    }

    /// An empty tensor is accepted and still counted as a batch.
    #[test]
    fn test_empty_tensor() {
        let mut observer = Observer::min_max();

        let data: Vec<f32> = vec![];
        let tensor = tensor_1d(&data).unwrap();

        let result = observer.update(&tensor);
        assert!(result.is_ok());
        assert_eq!(observer.num_batches(), 1);
    }

    /// Only I8 and U8 are valid quantization targets.
    #[test]
    fn test_unsupported_dtype() {
        let mut observer = Observer::min_max();

        let data = vec![1.0, 2.0, 3.0];
        let tensor = tensor_1d(&data).unwrap();
        observer.update(&tensor).unwrap();

        let result = observer.calculate_qparams(DType::F32);
        assert!(result.is_err());
    }
}