// tensorlogic_infer/tensor_stats.rs

//! Tensor statistics and anomaly detection for inference monitoring.
//!
//! Provides per-tensor statistics (mean, std, min/max, percentiles), NaN/Inf and outlier
//! detection with configurable z-score and IQR parameters, and activation history tracking
//! for debugging training pipelines.

use std::collections::{HashMap, VecDeque};
use thiserror::Error;

9/// Error types for statistics operations.
10#[derive(Debug, Error)]
11pub enum StatsError {
12    /// The input data slice was empty.
13    #[error("Empty data slice")]
14    EmptyData,
15    /// The requested percentile was out of range [0, 1].
16    #[error("Invalid percentile: {0}")]
17    InvalidPercentile(f64),
18}
19
/// Summary statistics for one tensor (a flat slice of `f64`), as produced by
/// [`TensorStats::compute`].
///
/// Every statistic except the three counters is taken over the finite elements only; when a
/// tensor has no finite elements those statistics are NaN.
#[derive(Clone, Debug)]
pub struct TensorStats {
    /// Arithmetic mean of the finite values.
    pub mean: f64,
    /// Population (divide-by-`n`) standard deviation of the finite values.
    pub std: f64,
    /// Smallest finite value.
    pub min: f64,
    /// Largest finite value.
    pub max: f64,
    /// First quartile (25th percentile), linearly interpolated.
    pub p25: f64,
    /// Median (50th percentile), linearly interpolated.
    pub p50: f64,
    /// Third quartile (75th percentile), linearly interpolated.
    pub p75: f64,
    /// How many elements were NaN.
    pub nan_count: usize,
    /// How many elements were +Inf or -Inf.
    pub inf_count: usize,
    /// Total number of elements, finite or not.
    pub element_count: usize,
}

45impl TensorStats {
46    /// Compute statistics from a slice of f64 values.
47    ///
48    /// NaN and Inf values are counted but excluded from statistical calculations.
49    pub fn compute(data: &[f64]) -> Result<Self, StatsError> {
50        if data.is_empty() {
51            return Err(StatsError::EmptyData);
52        }
53        let nan_count = data.iter().filter(|v| v.is_nan()).count();
54        let inf_count = data.iter().filter(|v| v.is_infinite()).count();
55
56        // Filter to finite values for stats
57        let mut finite: Vec<f64> = data.iter().copied().filter(|v| v.is_finite()).collect();
58        if finite.is_empty() {
59            return Ok(TensorStats {
60                mean: f64::NAN,
61                std: f64::NAN,
62                min: f64::NAN,
63                max: f64::NAN,
64                p25: f64::NAN,
65                p50: f64::NAN,
66                p75: f64::NAN,
67                nan_count,
68                inf_count,
69                element_count: data.len(),
70            });
71        }
72        finite.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
73
74        let n = finite.len() as f64;
75        let mean = finite.iter().sum::<f64>() / n;
76        let variance = finite.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / n;
77        let std = variance.sqrt();
78        let min = finite[0];
79        let max = finite[finite.len() - 1];
80
81        // Percentiles using linear interpolation
82        let p25 = percentile(&finite, 0.25);
83        let p50 = percentile(&finite, 0.50);
84        let p75 = percentile(&finite, 0.75);
85
86        Ok(TensorStats {
87            mean,
88            std,
89            min,
90            max,
91            p25,
92            p50,
93            p75,
94            nan_count,
95            inf_count,
96            element_count: data.len(),
97        })
98    }
99
100    /// Whether any NaN or Inf values are present.
101    pub fn has_anomalies(&self) -> bool {
102        self.nan_count > 0 || self.inf_count > 0
103    }
104
105    /// Interquartile range (p75 - p25).
106    pub fn iqr(&self) -> f64 {
107        self.p75 - self.p25
108    }
109
110    /// Range (max - min).
111    pub fn range(&self) -> f64 {
112        self.max - self.min
113    }
114
115    /// Coefficient of variation (std / |mean|).
116    pub fn cv(&self) -> f64 {
117        if self.mean.abs() < 1e-15 {
118            f64::INFINITY
119        } else {
120            self.std / self.mean.abs()
121        }
122    }
123}
124
/// Return the `p`-quantile of `sorted` by linear interpolation between adjacent ranks.
///
/// `sorted` must already be in ascending order. `p` is a fraction, so `0.25` is the 25th
/// percentile. Values of `p` outside `[0, 1]` are clamped into that range. Returns NaN for an
/// empty slice. A single-element slice yields that element for any `p`. With two or more
/// elements, a NaN `p` yields NaN.
fn percentile(sorted: &[f64], p: f64) -> f64 {
    if sorted.is_empty() {
        return f64::NAN;
    }
    if sorted.len() == 1 {
        return sorted[0];
    }
    // Clamp so `lo` can never run past the end of the slice (p > 1, which used to panic on
    // the index) and `frac` can never go negative (p < 0, which used to extrapolate below
    // the minimum).
    let p = p.clamp(0.0, 1.0);
    let last = sorted.len() - 1;
    let idx = p * last as f64;
    let lo = idx.floor() as usize;
    let hi = (lo + 1).min(last);
    let frac = idx - lo as f64;
    sorted[lo] * (1.0 - frac) + sorted[hi] * frac
}

/// Category of a single detected anomaly.
#[derive(Clone, Debug, PartialEq)]
pub enum AnomalyKind {
    /// The element is NaN.
    NaN,
    /// The element is +Inf or -Inf.
    Inf,
    /// A finite element far from the rest of the data (its z-score exceeds the detector's
    /// threshold).
    Outlier {
        /// Absolute z-score of the element, relative to the mean and standard deviation of
        /// the finite values.
        z_score: f64,
    },
    /// The tensor shows no variation (near-zero spread, or only one finite value among
    /// several elements); this could indicate a dead neuron.
    Constant,
}

156/// Report of anomalies found in a tensor.
157#[derive(Debug, Clone)]
158pub struct AnomalyReport {
159    /// List of (element_index, anomaly_kind) pairs.
160    pub anomalies: Vec<(usize, AnomalyKind)>,
161    /// Total anomaly count.
162    pub anomaly_count: usize,
163    /// True if no anomalies found.
164    pub is_clean: bool,
165}
166
/// Configurable detector for NaN/Inf values, statistical outliers and constant tensors.
///
/// Create one with [`AnomalyDetector::new`] (or `Default`) and adjust it with the `with_*`
/// builder methods. `Debug`/`Clone`/`PartialEq` let callers log, copy and compare
/// configurations.
#[derive(Debug, Clone, PartialEq)]
pub struct AnomalyDetector {
    /// IQR multiplier for outlier detection (default 1.5).
    pub iqr_multiplier: f64,
    /// Absolute z-score threshold for outlier detection (default 3.0).
    pub z_score_threshold: f64,
    /// Whether to report constant-valued tensors (default `true`).
    pub check_constant: bool,
}

177impl Default for AnomalyDetector {
178    fn default() -> Self {
179        Self::new()
180    }
181}
182
183impl AnomalyDetector {
184    /// Create a new anomaly detector with default settings.
185    pub fn new() -> Self {
186        AnomalyDetector {
187            iqr_multiplier: 1.5,
188            z_score_threshold: 3.0,
189            check_constant: true,
190        }
191    }
192
193    /// Set the IQR multiplier for outlier detection.
194    pub fn with_iqr_multiplier(mut self, m: f64) -> Self {
195        self.iqr_multiplier = m;
196        self
197    }
198
199    /// Set the z-score threshold for outlier detection.
200    pub fn with_z_score_threshold(mut self, t: f64) -> Self {
201        self.z_score_threshold = t;
202        self
203    }
204
205    /// Set whether to check for constant-valued tensors.
206    pub fn with_check_constant(mut self, c: bool) -> Self {
207        self.check_constant = c;
208        self
209    }
210
211    /// Detect anomalies in data.
212    pub fn detect(&self, data: &[f64]) -> AnomalyReport {
213        let mut anomalies = Vec::new();
214
215        // Check NaN and Inf
216        for (i, &v) in data.iter().enumerate() {
217            if v.is_nan() {
218                anomalies.push((i, AnomalyKind::NaN));
219            } else if v.is_infinite() {
220                anomalies.push((i, AnomalyKind::Inf));
221            }
222        }
223
224        // Compute stats for outlier detection (finite values only)
225        let finite: Vec<f64> = data.iter().copied().filter(|v| v.is_finite()).collect();
226        if finite.len() >= 2 {
227            let mean = finite.iter().sum::<f64>() / finite.len() as f64;
228            let std = (finite.iter().map(|v| (v - mean).powi(2)).sum::<f64>()
229                / finite.len() as f64)
230                .sqrt();
231
232            if std > 1e-15 {
233                for (i, &v) in data.iter().enumerate() {
234                    if v.is_finite() {
235                        let z = ((v - mean) / std).abs();
236                        if z > self.z_score_threshold {
237                            anomalies.push((i, AnomalyKind::Outlier { z_score: z }));
238                        }
239                    }
240                }
241            }
242
243            // Check constant
244            if self.check_constant && std < 1e-15 {
245                anomalies.push((0, AnomalyKind::Constant));
246            }
247        } else if self.check_constant && finite.len() == 1 && data.len() > 1 {
248            // All same value or only one finite value among many
249            anomalies.push((0, AnomalyKind::Constant));
250        }
251
252        let count = anomalies.len();
253        AnomalyReport {
254            anomalies,
255            anomaly_count: count,
256            is_clean: count == 0,
257        }
258    }
259}
260
261/// Track statistics for named tensors across training steps.
262pub struct ActivationStatistics {
263    history: HashMap<String, Vec<TensorStats>>,
264    max_history: usize,
265}
266
267impl ActivationStatistics {
268    /// Create a new activation statistics tracker with the given history limit.
269    pub fn new(max_history: usize) -> Self {
270        ActivationStatistics {
271            history: HashMap::new(),
272            max_history: max_history.max(1),
273        }
274    }
275
276    /// Record statistics for a named tensor.
277    pub fn record(&mut self, name: &str, data: &[f64]) -> Result<(), StatsError> {
278        let stats = TensorStats::compute(data)?;
279        let entry = self.history.entry(name.to_string()).or_default();
280        entry.push(stats);
281        if entry.len() > self.max_history {
282            entry.remove(0);
283        }
284        Ok(())
285    }
286
287    /// Get the most recent stats for a named tensor.
288    pub fn latest(&self, name: &str) -> Option<&TensorStats> {
289        self.history.get(name).and_then(|v| v.last())
290    }
291
292    /// Get the trend of means over history for a named tensor.
293    pub fn trend_mean(&self, name: &str) -> Option<Vec<f64>> {
294        self.history
295            .get(name)
296            .map(|v| v.iter().map(|s| s.mean).collect())
297    }
298
299    /// Get the trend of stds over history.
300    pub fn trend_std(&self, name: &str) -> Option<Vec<f64>> {
301        self.history
302            .get(name)
303            .map(|v| v.iter().map(|s| s.std).collect())
304    }
305
306    /// Iterator over all tracked tensor names.
307    pub fn names(&self) -> impl Iterator<Item = &String> {
308        self.history.keys()
309    }
310
311    /// Number of tracked tensors.
312    pub fn tracked_count(&self) -> usize {
313        self.history.len()
314    }
315
316    /// Clear all history.
317    pub fn clear(&mut self) {
318        self.history.clear();
319    }
320}
321
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used by most float comparisons below.
    const EPSILON: f64 = 1e-10;

    /// True when `a` and `b` differ by strictly less than `eps`.
    fn approx_eq(a: f64, b: f64, eps: f64) -> bool {
        (a - b).abs() < eps
    }

    /// `TensorStats::compute` for inputs that are expected to succeed.
    fn stats_of(data: &[f64]) -> TensorStats {
        TensorStats::compute(data).expect("compute failed")
    }

    /// Whether any entry of `report` has a kind accepted by `pred`.
    fn reports_kind(report: &AnomalyReport, pred: impl Fn(&AnomalyKind) -> bool) -> bool {
        report.anomalies.iter().any(|(_, kind)| pred(kind))
    }

    /// The integers 1..=100 as floats.
    fn one_to_hundred() -> Vec<f64> {
        (1..=100_u32).map(f64::from).collect()
    }

    /// Tracker (history limit 10) holding three batches under "layer1" with means 2, 5 and 8.
    fn tracker_with_three_batches() -> ActivationStatistics {
        let mut tracker = ActivationStatistics::new(10);
        for batch in [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]] {
            tracker.record("layer1", &batch).expect("record failed");
        }
        tracker
    }

    #[test]
    fn test_stats_basic() {
        let s = stats_of(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        assert!(approx_eq(s.mean, 3.0, EPSILON));
        // Population standard deviation of 1..=5 is sqrt(2).
        assert!(approx_eq(s.std, 2.0_f64.sqrt(), 1e-6));
        assert!(approx_eq(s.min, 1.0, EPSILON));
        assert!(approx_eq(s.max, 5.0, EPSILON));
    }

    #[test]
    fn test_stats_percentiles() {
        let s = stats_of(&one_to_hundred());
        for (actual, expected) in [(s.p25, 25.75), (s.p50, 50.5), (s.p75, 75.25)] {
            assert!(approx_eq(actual, expected, 1e-6));
        }
    }

    #[test]
    fn test_stats_single_element() {
        let s = stats_of(&[42.0]);
        for (actual, expected) in [(s.mean, 42.0), (s.std, 0.0), (s.min, 42.0), (s.max, 42.0)] {
            assert!(approx_eq(actual, expected, EPSILON));
        }
    }

    #[test]
    fn test_stats_all_same() {
        let s = stats_of(&[5.0; 4]);
        assert!(approx_eq(s.std, 0.0, EPSILON));
        assert!(approx_eq(s.iqr(), 0.0, EPSILON));
    }

    #[test]
    fn test_stats_nan_count() {
        let s = stats_of(&[1.0, f64::NAN, 3.0]);
        assert_eq!(s.nan_count, 1);
        assert!(approx_eq(s.mean, 2.0, EPSILON));
    }

    #[test]
    fn test_stats_inf_count() {
        assert_eq!(stats_of(&[1.0, f64::INFINITY, 3.0]).inf_count, 1);
    }

    #[test]
    fn test_stats_has_anomalies() {
        assert!(stats_of(&[1.0, f64::NAN, 3.0]).has_anomalies());
    }

    #[test]
    fn test_stats_empty_err() {
        let empty: &[f64] = &[];
        let result = TensorStats::compute(empty);
        assert!(result.is_err());
        assert!(matches!(result, Err(StatsError::EmptyData)));
    }

    #[test]
    fn test_stats_iqr() {
        let s = stats_of(&one_to_hundred());
        assert!(approx_eq(s.iqr(), s.p75 - s.p25, EPSILON));
    }

    #[test]
    fn test_stats_cv() {
        let s = stats_of(&[2.0, 4.0, 6.0, 8.0, 10.0]);
        assert!(approx_eq(s.cv(), s.std / s.mean.abs(), EPSILON));
    }

    #[test]
    fn test_anomaly_clean() {
        let report = AnomalyDetector::new().detect(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        assert!(report.is_clean);
    }

    #[test]
    fn test_anomaly_nan() {
        let report = AnomalyDetector::new().detect(&[f64::NAN]);
        assert!(!report.is_clean);
        assert!(reports_kind(&report, |k| matches!(k, AnomalyKind::NaN)));
    }

    #[test]
    fn test_anomaly_inf() {
        let report = AnomalyDetector::new().detect(&[f64::INFINITY]);
        assert!(!report.is_clean);
        assert!(reports_kind(&report, |k| matches!(k, AnomalyKind::Inf)));
    }

    #[test]
    fn test_anomaly_outlier_zscore() {
        let detector = AnomalyDetector::new().with_z_score_threshold(1.5);
        let report = detector.detect(&[0.0, 0.0, 0.0, 0.0, 100.0]);
        assert!(!report.is_clean);
        assert!(reports_kind(&report, |k| matches!(k, AnomalyKind::Outlier { .. })));
    }

    #[test]
    fn test_anomaly_constant() {
        let report = AnomalyDetector::new().detect(&[7.0; 4]);
        assert!(!report.is_clean);
        assert!(reports_kind(&report, |k| matches!(k, AnomalyKind::Constant)));
    }

    #[test]
    fn test_anomaly_no_constant_when_disabled() {
        let detector = AnomalyDetector::new().with_check_constant(false);
        let report = detector.detect(&[7.0; 4]);
        assert!(report.is_clean);
    }

    #[test]
    fn test_activation_record_and_latest() {
        let tracker = tracker_with_three_batches();
        let latest = tracker.latest("layer1").expect("no latest");
        assert!(approx_eq(latest.mean, 8.0, EPSILON));
    }

    #[test]
    fn test_activation_trend_mean() {
        let trend = tracker_with_three_batches()
            .trend_mean("layer1")
            .expect("no trend");
        assert_eq!(trend.len(), 3);
        for (got, want) in trend.iter().zip([2.0, 5.0, 8.0]) {
            assert!(approx_eq(*got, want, EPSILON));
        }
    }

    #[test]
    fn test_activation_max_history_cap() {
        let mut tracker = ActivationStatistics::new(2);
        for step in 0_u8..5 {
            tracker
                .record("layer1", &[f64::from(step)])
                .expect("record failed");
        }
        let trend = tracker.trend_mean("layer1").expect("no trend");
        assert_eq!(trend.len(), 2);
        // Only the two most recent single-value batches (3.0 and 4.0) survive.
        assert!(approx_eq(trend[0], 3.0, EPSILON));
        assert!(approx_eq(trend[1], 4.0, EPSILON));
    }

    #[test]
    fn test_activation_clear() {
        let mut tracker = ActivationStatistics::new(10);
        for (name, batch) in [("layer1", [1.0, 2.0]), ("layer2", [3.0, 4.0])] {
            tracker.record(name, &batch).expect("record failed");
        }
        assert_eq!(tracker.tracked_count(), 2);
        tracker.clear();
        assert_eq!(tracker.tracked_count(), 0);
        assert!(tracker.latest("layer1").is_none());
    }
}