Skip to main content

trustformers_debug/utilities/
tensor_analysis.rs

1//! Tensor analysis utilities and statistical functions
2
3use anyhow::Result;
4use scirs2_core::ndarray::*; // SciRS2 Integration Policy - was: use ndarray::{Array, ArrayD};
5use serde::{Deserialize, Serialize};
6
7/// Batch tensor analysis result
8#[derive(Debug, Serialize, Deserialize)]
9pub struct BatchTensorAnalysis {
10    pub individual_results: Vec<TensorAnalysisResult>,
11    pub overall_statistics: TensorStatistics,
12    pub batch_size: usize,
13    pub analysis_timestamp: chrono::DateTime<chrono::Utc>,
14}
15
16/// Individual tensor analysis result
17#[derive(Debug, Serialize, Deserialize)]
18pub struct TensorAnalysisResult {
19    pub tensor_index: usize,
20    pub shape: Vec<usize>,
21    pub statistics: TensorStatistics,
22    pub anomalies: Vec<TensorAnomaly>,
23}
24
25/// Comprehensive tensor statistics
26#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct TensorStatistics {
28    pub count: usize,
29    pub mean: f32,
30    pub std_dev: f32,
31    pub min: f32,
32    pub max: f32,
33    pub median: f32,
34    pub p25: f32,
35    pub p75: f32,
36    pub nan_count: usize,
37    pub inf_count: usize,
38    pub zero_count: usize,
39    pub skewness: f32,
40    pub kurtosis: f32,
41}
42
43impl Default for TensorStatistics {
44    fn default() -> Self {
45        Self {
46            count: 0,
47            mean: 0.0,
48            std_dev: 0.0,
49            min: 0.0,
50            max: 0.0,
51            median: 0.0,
52            p25: 0.0,
53            p75: 0.0,
54            nan_count: 0,
55            inf_count: 0,
56            zero_count: 0,
57            skewness: 0.0,
58            kurtosis: 0.0,
59        }
60    }
61}
62
63impl TensorStatistics {
64    pub fn accumulate(&mut self, other: &TensorStatistics) {
65        self.count += other.count;
66        self.mean += other.mean;
67        self.std_dev += other.std_dev;
68        self.min = self.min.min(other.min);
69        self.max = self.max.max(other.max);
70        self.nan_count += other.nan_count;
71        self.inf_count += other.inf_count;
72        self.zero_count += other.zero_count;
73    }
74
75    pub fn finalize(&mut self, batch_size: usize) {
76        if batch_size > 0 {
77            self.mean /= batch_size as f32;
78            self.std_dev /= batch_size as f32;
79        }
80    }
81}
82
83/// Tensor anomaly detection result
84#[derive(Debug, Serialize, Deserialize)]
85pub struct TensorAnomaly {
86    pub anomaly_type: AnomalyType,
87    pub severity: AnomalySeverity,
88    pub description: String,
89    pub suggested_fix: String,
90}
91
92/// Types of tensor anomalies
93#[derive(Debug, Serialize, Deserialize)]
94pub enum AnomalyType {
95    NanValues,
96    InfiniteValues,
97    ExtremeSkewness,
98    ExtremeKurtosis,
99    DeadNeurons,
100    ExtremeValues,
101    Saturation,
102    Outliers,
103}
104
105/// Severity levels for anomalies
106#[derive(Debug, Serialize, Deserialize)]
107pub enum AnomalySeverity {
108    Low,
109    Medium,
110    High,
111    Critical,
112}
113
114/// Advanced tensor analysis utilities
115pub struct TensorAnalyzer;
116
117impl TensorAnalyzer {
118    /// Batch tensor analysis with statistical insights
119    pub fn analyze_tensors_batch(tensors: &[ArrayD<f32>]) -> Result<BatchTensorAnalysis> {
120        let mut results = Vec::new();
121        let mut overall_stats = TensorStatistics::default();
122
123        for (i, tensor) in tensors.iter().enumerate() {
124            let stats = Self::compute_tensor_statistics(tensor)?;
125            let anomalies = Self::detect_tensor_anomalies(&stats);
126
127            results.push(TensorAnalysisResult {
128                tensor_index: i,
129                shape: tensor.shape().to_vec(),
130                statistics: stats.clone(),
131                anomalies,
132            });
133
134            overall_stats.accumulate(&stats);
135        }
136
137        overall_stats.finalize(tensors.len());
138
139        Ok(BatchTensorAnalysis {
140            individual_results: results,
141            overall_statistics: overall_stats,
142            batch_size: tensors.len(),
143            analysis_timestamp: chrono::Utc::now(),
144        })
145    }
146
147    /// Compute comprehensive statistics for a tensor
148    pub fn compute_tensor_statistics(tensor: &ArrayD<f32>) -> Result<TensorStatistics> {
149        let data: Vec<f32> = tensor.iter().cloned().collect();
150        let count = data.len();
151
152        if count == 0 {
153            return Ok(TensorStatistics::default());
154        }
155
156        // Basic statistics
157        let sum: f32 = data.iter().sum();
158        let mean = sum / count as f32;
159
160        let variance = data.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / count as f32;
161        let std_dev = variance.sqrt();
162
163        // Min/max
164        let min = data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
165        let max = data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
166
167        // Percentiles
168        let mut sorted_data = data.clone();
169        sorted_data.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
170
171        let median = Self::percentile(&sorted_data, 50.0);
172        let p25 = Self::percentile(&sorted_data, 25.0);
173        let p75 = Self::percentile(&sorted_data, 75.0);
174
175        // Count special values
176        let nan_count = data.iter().filter(|&&x| x.is_nan()).count();
177        let inf_count = data.iter().filter(|&&x| x.is_infinite()).count();
178        let zero_count = data.iter().filter(|&&x| x == 0.0).count();
179
180        // Higher order moments
181        let skewness = Self::compute_skewness(&data, mean, std_dev);
182        let kurtosis = Self::compute_kurtosis(&data, mean, std_dev);
183
184        Ok(TensorStatistics {
185            count,
186            mean,
187            std_dev,
188            min,
189            max,
190            median,
191            p25,
192            p75,
193            nan_count,
194            inf_count,
195            zero_count,
196            skewness,
197            kurtosis,
198        })
199    }
200
201    /// Detect anomalies in tensor statistics
202    pub fn detect_tensor_anomalies(stats: &TensorStatistics) -> Vec<TensorAnomaly> {
203        let mut anomalies = Vec::new();
204
205        // Check for NaN values
206        if stats.nan_count > 0 {
207            anomalies.push(TensorAnomaly {
208                anomaly_type: AnomalyType::NanValues,
209                severity: AnomalySeverity::Critical,
210                description: format!("Found {} NaN values in tensor", stats.nan_count),
211                suggested_fix: "Check for division by zero or invalid operations".to_string(),
212            });
213        }
214
215        // Check for infinite values
216        if stats.inf_count > 0 {
217            anomalies.push(TensorAnomaly {
218                anomaly_type: AnomalyType::InfiniteValues,
219                severity: AnomalySeverity::High,
220                description: format!("Found {} infinite values in tensor", stats.inf_count),
221                suggested_fix: "Check for overflow or division by zero".to_string(),
222            });
223        }
224
225        // Check for extreme skewness
226        if stats.skewness.abs() > 3.0 {
227            anomalies.push(TensorAnomaly {
228                anomaly_type: AnomalyType::ExtremeSkewness,
229                severity: AnomalySeverity::Medium,
230                description: format!("Extreme skewness detected: {:.2}", stats.skewness),
231                suggested_fix: "Consider data normalization or outlier removal".to_string(),
232            });
233        }
234
235        // Check for extreme kurtosis
236        if stats.kurtosis > 10.0 {
237            anomalies.push(TensorAnomaly {
238                anomaly_type: AnomalyType::ExtremeKurtosis,
239                severity: AnomalySeverity::Medium,
240                description: format!("High kurtosis detected: {:.2}", stats.kurtosis),
241                suggested_fix: "Check for outliers or distribution issues".to_string(),
242            });
243        }
244
245        // Check for dead neurons (too many zeros)
246        let zero_ratio = stats.zero_count as f32 / stats.count as f32;
247        if zero_ratio > 0.5 {
248            anomalies.push(TensorAnomaly {
249                anomaly_type: AnomalyType::DeadNeurons,
250                severity: AnomalySeverity::High,
251                description: format!("High zero ratio: {:.2}%", zero_ratio * 100.0),
252                suggested_fix:
253                    "Check learning rate, weight initialization, or activation functions"
254                        .to_string(),
255            });
256        }
257
258        // Check for extreme values
259        let range = stats.max - stats.min;
260        if range > 1000.0 || stats.max.abs() > 100.0 || stats.min.abs() > 100.0 {
261            anomalies.push(TensorAnomaly {
262                anomaly_type: AnomalyType::ExtremeValues,
263                severity: AnomalySeverity::Medium,
264                description: format!("Extreme value range: [{:.2}, {:.2}]", stats.min, stats.max),
265                suggested_fix: "Consider gradient clipping or weight regularization".to_string(),
266            });
267        }
268
269        anomalies
270    }
271
272    /// Calculate percentile of sorted data
273    fn percentile(sorted_data: &[f32], percentile: f32) -> f32 {
274        if sorted_data.is_empty() {
275            return 0.0;
276        }
277
278        let index = (percentile / 100.0) * (sorted_data.len() - 1) as f32;
279        let lower_index = index.floor() as usize;
280        let upper_index = (index.ceil() as usize).min(sorted_data.len() - 1);
281
282        if lower_index == upper_index {
283            sorted_data[lower_index]
284        } else {
285            let weight = index - lower_index as f32;
286            sorted_data[lower_index] * (1.0 - weight) + sorted_data[upper_index] * weight
287        }
288    }
289
290    /// Compute skewness
291    fn compute_skewness(data: &[f32], mean: f32, std_dev: f32) -> f32 {
292        if std_dev == 0.0 || data.len() < 3 {
293            return 0.0;
294        }
295
296        let n = data.len() as f32;
297        let skewness = data.iter().map(|&x| ((x - mean) / std_dev).powi(3)).sum::<f32>() / n;
298
299        skewness
300    }
301
302    /// Compute kurtosis
303    fn compute_kurtosis(data: &[f32], mean: f32, std_dev: f32) -> f32 {
304        if std_dev == 0.0 || data.len() < 4 {
305            return 0.0;
306        }
307
308        let n = data.len() as f32;
309        let kurtosis = data.iter().map(|&x| ((x - mean) / std_dev).powi(4)).sum::<f32>() / n;
310
311        kurtosis - 3.0 // Excess kurtosis
312    }
313
314    /// Compare tensors for drift detection
315    pub fn compare_tensors(
316        baseline: &ArrayD<f32>,
317        current: &ArrayD<f32>,
318    ) -> Result<TensorComparisonResult> {
319        let baseline_stats = Self::compute_tensor_statistics(baseline)?;
320        let current_stats = Self::compute_tensor_statistics(current)?;
321
322        // Calculate various drift metrics
323        let mean_drift = (current_stats.mean - baseline_stats.mean).abs();
324        let std_drift = (current_stats.std_dev - baseline_stats.std_dev).abs();
325        let distribution_shift = Self::compute_distribution_shift(&baseline_stats, &current_stats);
326
327        let drift_severity = if mean_drift > 1.0 || std_drift > 1.0 || distribution_shift > 0.5 {
328            TensorDriftSeverity::High
329        } else if mean_drift > 0.5 || std_drift > 0.5 || distribution_shift > 0.3 {
330            TensorDriftSeverity::Medium
331        } else {
332            TensorDriftSeverity::Low
333        };
334
335        Ok(TensorComparisonResult {
336            baseline_stats,
337            current_stats,
338            mean_drift,
339            std_drift,
340            distribution_shift,
341            drift_severity: drift_severity.clone(),
342            recommendations: Self::generate_drift_recommendations(
343                drift_severity,
344                mean_drift,
345                std_drift,
346            ),
347        })
348    }
349
350    /// Compute distribution shift between two sets of statistics
351    fn compute_distribution_shift(baseline: &TensorStatistics, current: &TensorStatistics) -> f32 {
352        // Simple distribution shift metric based on statistical differences
353        let mean_diff = ((current.mean - baseline.mean) / (baseline.std_dev + 1e-8)).abs();
354        let std_diff = ((current.std_dev - baseline.std_dev) / (baseline.std_dev + 1e-8)).abs();
355        let skew_diff = (current.skewness - baseline.skewness).abs();
356
357        (mean_diff + std_diff + skew_diff * 0.5) / 2.5
358    }
359
360    /// Generate recommendations based on drift severity
361    fn generate_drift_recommendations(
362        severity: TensorDriftSeverity,
363        mean_drift: f32,
364        std_drift: f32,
365    ) -> Vec<String> {
366        let mut recommendations = Vec::new();
367
368        match severity {
369            TensorDriftSeverity::High => {
370                recommendations.push("Significant tensor drift detected".to_string());
371                if mean_drift > 1.0 {
372                    recommendations.push("Consider retraining or data rebalancing".to_string());
373                }
374                if std_drift > 1.0 {
375                    recommendations.push("Check for changes in data preprocessing".to_string());
376                }
377            },
378            TensorDriftSeverity::Medium => {
379                recommendations.push("Moderate tensor drift detected".to_string());
380                recommendations.push("Monitor closely for further changes".to_string());
381            },
382            TensorDriftSeverity::Low => {
383                recommendations.push("Minimal tensor drift - within acceptable range".to_string());
384            },
385        }
386
387        recommendations
388    }
389}
390
391/// Result of tensor comparison for drift detection
392#[derive(Debug, Serialize, Deserialize)]
393pub struct TensorComparisonResult {
394    pub baseline_stats: TensorStatistics,
395    pub current_stats: TensorStatistics,
396    pub mean_drift: f32,
397    pub std_drift: f32,
398    pub distribution_shift: f32,
399    pub drift_severity: TensorDriftSeverity,
400    pub recommendations: Vec<String>,
401}
402
403/// Drift severity levels
404#[derive(Debug, Clone, Serialize, Deserialize)]
405pub enum TensorDriftSeverity {
406    Low,
407    Medium,
408    High,
409}