1use std::collections::HashMap;
7use thiserror::Error;
8
/// Errors produced when computing tensor statistics.
#[derive(Debug, Error)]
pub enum StatsError {
    /// The input slice contained no elements at all.
    #[error("Empty data slice")]
    EmptyData,
    /// A percentile outside the valid range was requested.
    /// NOTE(review): this variant is never constructed in this file —
    /// presumably reserved for a public percentile API; confirm.
    #[error("Invalid percentile: {0}")]
    InvalidPercentile(f64),
}
19
/// Summary statistics for a slice of `f64` values.
///
/// All numeric statistics are computed over the *finite* values only;
/// NaN and infinite elements are excluded and tallied separately. When no
/// finite value exists, every statistic is NaN (see `TensorStats::compute`).
#[derive(Debug, Clone)]
pub struct TensorStats {
    /// Arithmetic mean of the finite values.
    pub mean: f64,
    /// Population standard deviation of the finite values (divides by N, not N-1).
    pub std: f64,
    /// Smallest finite value.
    pub min: f64,
    /// Largest finite value.
    pub max: f64,
    /// 25th percentile (linear interpolation) of the finite values.
    pub p25: f64,
    /// 50th percentile (median, linear interpolation) of the finite values.
    pub p50: f64,
    /// 75th percentile (linear interpolation) of the finite values.
    pub p75: f64,
    /// Number of NaN elements in the original input.
    pub nan_count: usize,
    /// Number of +Inf/-Inf elements in the original input.
    pub inf_count: usize,
    /// Total number of elements in the original input (including non-finite).
    pub element_count: usize,
}
44
45impl TensorStats {
46 pub fn compute(data: &[f64]) -> Result<Self, StatsError> {
50 if data.is_empty() {
51 return Err(StatsError::EmptyData);
52 }
53 let nan_count = data.iter().filter(|v| v.is_nan()).count();
54 let inf_count = data.iter().filter(|v| v.is_infinite()).count();
55
56 let mut finite: Vec<f64> = data.iter().copied().filter(|v| v.is_finite()).collect();
58 if finite.is_empty() {
59 return Ok(TensorStats {
60 mean: f64::NAN,
61 std: f64::NAN,
62 min: f64::NAN,
63 max: f64::NAN,
64 p25: f64::NAN,
65 p50: f64::NAN,
66 p75: f64::NAN,
67 nan_count,
68 inf_count,
69 element_count: data.len(),
70 });
71 }
72 finite.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
73
74 let n = finite.len() as f64;
75 let mean = finite.iter().sum::<f64>() / n;
76 let variance = finite.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / n;
77 let std = variance.sqrt();
78 let min = finite[0];
79 let max = finite[finite.len() - 1];
80
81 let p25 = percentile(&finite, 0.25);
83 let p50 = percentile(&finite, 0.50);
84 let p75 = percentile(&finite, 0.75);
85
86 Ok(TensorStats {
87 mean,
88 std,
89 min,
90 max,
91 p25,
92 p50,
93 p75,
94 nan_count,
95 inf_count,
96 element_count: data.len(),
97 })
98 }
99
100 pub fn has_anomalies(&self) -> bool {
102 self.nan_count > 0 || self.inf_count > 0
103 }
104
105 pub fn iqr(&self) -> f64 {
107 self.p75 - self.p25
108 }
109
110 pub fn range(&self) -> f64 {
112 self.max - self.min
113 }
114
115 pub fn cv(&self) -> f64 {
117 if self.mean.abs() < 1e-15 {
118 f64::INFINITY
119 } else {
120 self.std / self.mean.abs()
121 }
122 }
123}
124
/// Linearly-interpolated percentile of an ascending-sorted slice.
///
/// `p` is a fraction in [0, 1] (0.25 => 25th percentile). Returns NaN for
/// an empty slice and the sole element for a one-element slice.
/// Assumes `sorted` is already in ascending order — callers in this file
/// sort before calling; TODO confirm any future external callers do too.
fn percentile(sorted: &[f64], p: f64) -> f64 {
    match sorted {
        [] => f64::NAN,
        [only] => *only,
        _ => {
            // Fractional rank into the slice: 0 maps to the first element,
            // 1 maps to the last.
            let rank = p * (sorted.len() - 1) as f64;
            let below = rank.floor() as usize;
            let above = (below + 1).min(sorted.len() - 1);
            let t = rank - below as f64;
            // Linear interpolation between the two bracketing samples.
            sorted[below] * (1.0 - t) + sorted[above] * t
        }
    }
}
139
/// Classification of a single detected anomaly.
#[derive(Debug, Clone, PartialEq)]
pub enum AnomalyKind {
    /// The element was NaN.
    NaN,
    /// The element was +Inf or -Inf.
    Inf,
    /// The element was finite but far from the mean.
    Outlier {
        /// Absolute z-score |(v - mean) / std| that triggered the flag.
        z_score: f64,
    },
    /// The finite values were (near-)constant; reported once at index 0.
    Constant,
}

/// Result of [`AnomalyDetector::detect`].
#[derive(Debug, Clone)]
pub struct AnomalyReport {
    /// Each anomaly as `(index into the input slice, kind)`.
    /// The `Constant` kind always carries index 0.
    pub anomalies: Vec<(usize, AnomalyKind)>,
    /// Convenience copy of `anomalies.len()`.
    pub anomaly_count: usize,
    /// True when no anomaly was found (`anomaly_count == 0`).
    pub is_clean: bool,
}

/// Configurable detector for NaN/Inf values, z-score outliers, and
/// constant data.
pub struct AnomalyDetector {
    /// IQR fence multiplier.
    /// NOTE(review): not consulted by `detect`, which uses z-scores only —
    /// confirm whether IQR-based fencing was intended here.
    pub iqr_multiplier: f64,
    /// Absolute z-score above which a finite value is flagged as an outlier.
    pub z_score_threshold: f64,
    /// Whether to flag (near-)constant data.
    pub check_constant: bool,
}

impl Default for AnomalyDetector {
    /// Delegates to [`AnomalyDetector::new`] for the default thresholds.
    fn default() -> Self {
        Self::new()
    }
}

impl AnomalyDetector {
    /// Creates a detector with the default configuration:
    /// IQR multiplier 1.5, z-score threshold 3.0, constant check enabled.
    pub fn new() -> Self {
        AnomalyDetector {
            iqr_multiplier: 1.5,
            z_score_threshold: 3.0,
            check_constant: true,
        }
    }

    /// Builder-style setter for the IQR fence multiplier.
    pub fn with_iqr_multiplier(mut self, m: f64) -> Self {
        self.iqr_multiplier = m;
        self
    }

    /// Builder-style setter for the outlier z-score threshold.
    pub fn with_z_score_threshold(mut self, t: f64) -> Self {
        self.z_score_threshold = t;
        self
    }

    /// Builder-style setter enabling/disabling the constant-data check.
    pub fn with_check_constant(mut self, c: bool) -> Self {
        self.check_constant = c;
        self
    }

    /// Scans `data` and reports NaN/Inf elements, z-score outliers, and
    /// (near-)constant data according to the configured thresholds.
    pub fn detect(&self, data: &[f64]) -> AnomalyReport {
        // Pass 1: flag every non-finite element at its index, in order.
        let mut anomalies: Vec<(usize, AnomalyKind)> = data
            .iter()
            .enumerate()
            .filter_map(|(i, v)| {
                if v.is_nan() {
                    Some((i, AnomalyKind::NaN))
                } else if v.is_infinite() {
                    Some((i, AnomalyKind::Inf))
                } else {
                    None
                }
            })
            .collect();

        let finite: Vec<f64> = data.iter().copied().filter(|v| v.is_finite()).collect();
        if finite.len() >= 2 {
            let n = finite.len() as f64;
            let mean = finite.iter().sum::<f64>() / n;
            // Population standard deviation over the finite values.
            let std = (finite.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / n).sqrt();

            if std > 1e-15 {
                // Pass 2: flag finite values whose |z| exceeds the threshold.
                for (i, &v) in data.iter().enumerate() {
                    if !v.is_finite() {
                        continue;
                    }
                    let z = ((v - mean) / std).abs();
                    if z > self.z_score_threshold {
                        anomalies.push((i, AnomalyKind::Outlier { z_score: z }));
                    }
                }
            }

            // Spread effectively zero: the finite data is constant.
            if self.check_constant && std < 1e-15 {
                anomalies.push((0, AnomalyKind::Constant));
            }
        } else if self.check_constant && finite.len() == 1 && data.len() > 1 {
            // A single finite value among several elements counts as constant.
            anomalies.push((0, AnomalyKind::Constant));
        }

        let anomaly_count = anomalies.len();
        AnomalyReport {
            anomalies,
            anomaly_count,
            is_clean: anomaly_count == 0,
        }
    }
}
260
/// Rolling per-name history of tensor statistics (e.g. one entry per layer).
pub struct ActivationStatistics {
    /// Snapshots per name, oldest first; each Vec is capped at `max_history`.
    history: HashMap<String, Vec<TensorStats>>,
    /// Maximum snapshots retained per name (clamped to at least 1 in `new`).
    max_history: usize,
}
266
267impl ActivationStatistics {
268 pub fn new(max_history: usize) -> Self {
270 ActivationStatistics {
271 history: HashMap::new(),
272 max_history: max_history.max(1),
273 }
274 }
275
276 pub fn record(&mut self, name: &str, data: &[f64]) -> Result<(), StatsError> {
278 let stats = TensorStats::compute(data)?;
279 let entry = self.history.entry(name.to_string()).or_default();
280 entry.push(stats);
281 if entry.len() > self.max_history {
282 entry.remove(0);
283 }
284 Ok(())
285 }
286
287 pub fn latest(&self, name: &str) -> Option<&TensorStats> {
289 self.history.get(name).and_then(|v| v.last())
290 }
291
292 pub fn trend_mean(&self, name: &str) -> Option<Vec<f64>> {
294 self.history
295 .get(name)
296 .map(|v| v.iter().map(|s| s.mean).collect())
297 }
298
299 pub fn trend_std(&self, name: &str) -> Option<Vec<f64>> {
301 self.history
302 .get(name)
303 .map(|v| v.iter().map(|s| s.std).collect())
304 }
305
306 pub fn names(&self) -> impl Iterator<Item = &String> {
308 self.history.keys()
309 }
310
311 pub fn tracked_count(&self) -> usize {
313 self.history.len()
314 }
315
316 pub fn clear(&mut self) {
318 self.history.clear();
319 }
320}
321
// Unit tests for TensorStats, AnomalyDetector, and ActivationStatistics.
#[cfg(test)]
mod tests {
    use super::*;

    // Absolute tolerance for near-exact float comparisons.
    const EPSILON: f64 = 1e-10;

    // True when `a` and `b` differ by less than `eps`.
    fn approx_eq(a: f64, b: f64, eps: f64) -> bool {
        (a - b).abs() < eps
    }

    // Mean/std/min/max over a small known slice; std is the population
    // std of 1..=5, i.e. sqrt(2).
    #[test]
    fn test_stats_basic() {
        let data = [1.0, 2.0, 3.0, 4.0, 5.0];
        let stats = TensorStats::compute(&data).expect("compute failed");
        assert!(approx_eq(stats.mean, 3.0, EPSILON));
        assert!(approx_eq(stats.std, 2.0_f64.sqrt(), 1e-6));
        assert!(approx_eq(stats.min, 1.0, EPSILON));
        assert!(approx_eq(stats.max, 5.0, EPSILON));
    }

    // Linear-interpolation percentiles over 1..=100.
    #[test]
    fn test_stats_percentiles() {
        let data: Vec<f64> = (1..=100).map(|i| i as f64).collect();
        let stats = TensorStats::compute(&data).expect("compute failed");
        assert!(approx_eq(stats.p25, 25.75, 1e-6));
        assert!(approx_eq(stats.p50, 50.5, 1e-6));
        assert!(approx_eq(stats.p75, 75.25, 1e-6));
    }

    // A single element: zero std, min == max == mean.
    #[test]
    fn test_stats_single_element() {
        let data = [42.0];
        let stats = TensorStats::compute(&data).expect("compute failed");
        assert!(approx_eq(stats.mean, 42.0, EPSILON));
        assert!(approx_eq(stats.std, 0.0, EPSILON));
        assert!(approx_eq(stats.min, 42.0, EPSILON));
        assert!(approx_eq(stats.max, 42.0, EPSILON));
    }

    // Constant data: zero std and zero IQR.
    #[test]
    fn test_stats_all_same() {
        let data = [5.0, 5.0, 5.0, 5.0];
        let stats = TensorStats::compute(&data).expect("compute failed");
        assert!(approx_eq(stats.std, 0.0, EPSILON));
        assert!(approx_eq(stats.iqr(), 0.0, EPSILON));
    }

    // NaN elements are counted and excluded from the mean.
    #[test]
    fn test_stats_nan_count() {
        let data = [1.0, f64::NAN, 3.0];
        let stats = TensorStats::compute(&data).expect("compute failed");
        assert_eq!(stats.nan_count, 1);
        assert!(approx_eq(stats.mean, 2.0, EPSILON));
    }

    // Infinite elements are counted separately.
    #[test]
    fn test_stats_inf_count() {
        let data = [1.0, f64::INFINITY, 3.0];
        let stats = TensorStats::compute(&data).expect("compute failed");
        assert_eq!(stats.inf_count, 1);
    }

    // has_anomalies fires on any NaN/Inf presence.
    #[test]
    fn test_stats_has_anomalies() {
        let data = [1.0, f64::NAN, 3.0];
        let stats = TensorStats::compute(&data).expect("compute failed");
        assert!(stats.has_anomalies());
    }

    // An empty slice yields StatsError::EmptyData.
    #[test]
    fn test_stats_empty_err() {
        let data: &[f64] = &[];
        let result = TensorStats::compute(data);
        assert!(result.is_err());
        assert!(matches!(result, Err(StatsError::EmptyData)));
    }

    // iqr() is exactly p75 - p25.
    #[test]
    fn test_stats_iqr() {
        let data: Vec<f64> = (1..=100).map(|i| i as f64).collect();
        let stats = TensorStats::compute(&data).expect("compute failed");
        let expected_iqr = stats.p75 - stats.p25;
        assert!(approx_eq(stats.iqr(), expected_iqr, EPSILON));
    }

    // cv() is std / |mean| for a non-zero mean.
    #[test]
    fn test_stats_cv() {
        let data = [2.0, 4.0, 6.0, 8.0, 10.0];
        let stats = TensorStats::compute(&data).expect("compute failed");
        let expected_cv = stats.std / stats.mean.abs();
        assert!(approx_eq(stats.cv(), expected_cv, EPSILON));
    }

    // Well-behaved data produces a clean report.
    #[test]
    fn test_anomaly_clean() {
        let detector = AnomalyDetector::new();
        let data = [1.0, 2.0, 3.0, 4.0, 5.0];
        let report = detector.detect(&data);
        assert!(report.is_clean);
    }

    // A lone NaN is flagged as AnomalyKind::NaN.
    #[test]
    fn test_anomaly_nan() {
        let detector = AnomalyDetector::new();
        let data = [f64::NAN];
        let report = detector.detect(&data);
        assert!(!report.is_clean);
        assert!(report
            .anomalies
            .iter()
            .any(|(_, k)| matches!(k, AnomalyKind::NaN)));
    }

    // A lone infinity is flagged as AnomalyKind::Inf.
    #[test]
    fn test_anomaly_inf() {
        let detector = AnomalyDetector::new();
        let data = [f64::INFINITY];
        let report = detector.detect(&data);
        assert!(!report.is_clean);
        assert!(report
            .anomalies
            .iter()
            .any(|(_, k)| matches!(k, AnomalyKind::Inf)));
    }

    // A lowered threshold makes the 100.0 spike an Outlier.
    #[test]
    fn test_anomaly_outlier_zscore() {
        let detector = AnomalyDetector::new().with_z_score_threshold(1.5);
        let data = [0.0, 0.0, 0.0, 0.0, 100.0];
        let report = detector.detect(&data);
        assert!(!report.is_clean);
        assert!(report
            .anomalies
            .iter()
            .any(|(_, k)| matches!(k, AnomalyKind::Outlier { .. })));
    }

    // All-equal data is flagged Constant by default.
    #[test]
    fn test_anomaly_constant() {
        let detector = AnomalyDetector::new();
        let data = [7.0, 7.0, 7.0, 7.0];
        let report = detector.detect(&data);
        assert!(!report.is_clean);
        assert!(report
            .anomalies
            .iter()
            .any(|(_, k)| matches!(k, AnomalyKind::Constant)));
    }

    // Disabling the constant check suppresses the Constant anomaly.
    #[test]
    fn test_anomaly_no_constant_when_disabled() {
        let detector = AnomalyDetector::new().with_check_constant(false);
        let data = [7.0, 7.0, 7.0, 7.0];
        let report = detector.detect(&data);
        assert!(report.is_clean);
    }

    // latest() returns the most recently recorded snapshot.
    #[test]
    fn test_activation_record_and_latest() {
        let mut tracker = ActivationStatistics::new(10);
        tracker
            .record("layer1", &[1.0, 2.0, 3.0])
            .expect("record failed");
        tracker
            .record("layer1", &[4.0, 5.0, 6.0])
            .expect("record failed");
        tracker
            .record("layer1", &[7.0, 8.0, 9.0])
            .expect("record failed");
        let latest = tracker.latest("layer1").expect("no latest");
        assert!(approx_eq(latest.mean, 8.0, EPSILON));
    }

    // trend_mean() returns one mean per recorded snapshot, oldest first.
    #[test]
    fn test_activation_trend_mean() {
        let mut tracker = ActivationStatistics::new(10);
        tracker
            .record("layer1", &[1.0, 2.0, 3.0])
            .expect("record failed");
        tracker
            .record("layer1", &[4.0, 5.0, 6.0])
            .expect("record failed");
        tracker
            .record("layer1", &[7.0, 8.0, 9.0])
            .expect("record failed");
        let trend = tracker.trend_mean("layer1").expect("no trend");
        assert_eq!(trend.len(), 3);
        assert!(approx_eq(trend[0], 2.0, EPSILON));
        assert!(approx_eq(trend[1], 5.0, EPSILON));
        assert!(approx_eq(trend[2], 8.0, EPSILON));
    }

    // Only the last max_history snapshots are retained (here: means 3, 4).
    #[test]
    fn test_activation_max_history_cap() {
        let mut tracker = ActivationStatistics::new(2);
        for i in 0..5 {
            let data = [i as f64];
            tracker.record("layer1", &data).expect("record failed");
        }
        let trend = tracker.trend_mean("layer1").expect("no trend");
        assert_eq!(trend.len(), 2);
        assert!(approx_eq(trend[0], 3.0, EPSILON));
        assert!(approx_eq(trend[1], 4.0, EPSILON));
    }

    // clear() drops every tracked name and its history.
    #[test]
    fn test_activation_clear() {
        let mut tracker = ActivationStatistics::new(10);
        tracker
            .record("layer1", &[1.0, 2.0])
            .expect("record failed");
        tracker
            .record("layer2", &[3.0, 4.0])
            .expect("record failed");
        assert_eq!(tracker.tracked_count(), 2);
        tracker.clear();
        assert_eq!(tracker.tracked_count(), 0);
        assert!(tracker.latest("layer1").is_none());
    }
}
543}