1use anyhow::Result;
4use scirs2_core::ndarray::*; use serde::{Deserialize, Serialize};
6
/// Outcome of analyzing a whole batch of tensors.
///
/// Produced by `TensorAnalyzer::analyze_tensors_batch`.
#[derive(Debug, Serialize, Deserialize)]
pub struct BatchTensorAnalysis {
    /// One entry per input tensor, in input order.
    pub individual_results: Vec<TensorAnalysisResult>,
    /// Statistics aggregated over all tensors of the batch.
    pub overall_statistics: TensorStatistics,
    /// Number of tensors that were passed in.
    pub batch_size: usize,
    /// UTC time at which this result was assembled.
    pub analysis_timestamp: chrono::DateTime<chrono::Utc>,
}
15
/// Analysis of a single tensor within a batch.
#[derive(Debug, Serialize, Deserialize)]
pub struct TensorAnalysisResult {
    /// Position of the tensor in the input slice.
    pub tensor_index: usize,
    /// Shape of the tensor (length of each axis).
    pub shape: Vec<usize>,
    /// Descriptive statistics computed over all elements of the tensor.
    pub statistics: TensorStatistics,
    /// Problems detected from `statistics`; empty when none were found.
    pub anomalies: Vec<TensorAnomaly>,
}
24
/// Descriptive statistics over the elements of a tensor.
///
/// `TensorAnalyzer::compute_tensor_statistics` fills every field. When a value of this
/// type is used as a batch aggregate (see `accumulate` / `finalize`), only `count`,
/// `mean`, `std_dev`, `min`, `max` and the three counters are updated.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorStatistics {
    /// Number of elements.
    pub count: usize,
    /// Arithmetic mean of the elements.
    pub mean: f32,
    /// Population standard deviation (variance divided by `count`, not `count - 1`).
    pub std_dev: f32,
    /// Smallest element.
    pub min: f32,
    /// Largest element.
    pub max: f32,
    /// 50th percentile, linearly interpolated between neighbouring sorted elements.
    pub median: f32,
    /// 25th percentile, interpolated the same way as `median`.
    pub p25: f32,
    /// 75th percentile, interpolated the same way as `median`.
    pub p75: f32,
    /// Number of NaN elements.
    pub nan_count: usize,
    /// Number of positive or negative infinite elements.
    pub inf_count: usize,
    /// Number of elements that compare equal to `0.0`.
    pub zero_count: usize,
    /// Mean of the cubed standardized elements.
    pub skewness: f32,
    /// Excess kurtosis: mean of the standardized elements raised to the fourth power, minus 3.
    pub kurtosis: f32,
}
42
43impl Default for TensorStatistics {
44 fn default() -> Self {
45 Self {
46 count: 0,
47 mean: 0.0,
48 std_dev: 0.0,
49 min: 0.0,
50 max: 0.0,
51 median: 0.0,
52 p25: 0.0,
53 p75: 0.0,
54 nan_count: 0,
55 inf_count: 0,
56 zero_count: 0,
57 skewness: 0.0,
58 kurtosis: 0.0,
59 }
60 }
61}
62
63impl TensorStatistics {
64 pub fn accumulate(&mut self, other: &TensorStatistics) {
65 self.count += other.count;
66 self.mean += other.mean;
67 self.std_dev += other.std_dev;
68 self.min = self.min.min(other.min);
69 self.max = self.max.max(other.max);
70 self.nan_count += other.nan_count;
71 self.inf_count += other.inf_count;
72 self.zero_count += other.zero_count;
73 }
74
75 pub fn finalize(&mut self, batch_size: usize) {
76 if batch_size > 0 {
77 self.mean /= batch_size as f32;
78 self.std_dev /= batch_size as f32;
79 }
80 }
81}
82
/// A single problem detected in a tensor's statistics.
#[derive(Debug, Serialize, Deserialize)]
pub struct TensorAnomaly {
    /// Category of the problem.
    pub anomaly_type: AnomalyType,
    /// How serious the problem is considered to be.
    pub severity: AnomalySeverity,
    /// Human-readable summary, including the offending measurement.
    pub description: String,
    /// Human-readable hint on what to investigate or change.
    pub suggested_fix: String,
}
91
/// Category of a problem reported for a tensor.
#[derive(Debug, Serialize, Deserialize)]
pub enum AnomalyType {
    /// At least one element is NaN.
    NanValues,
    /// At least one element is positive or negative infinity.
    InfiniteValues,
    /// The absolute skewness exceeds the detector's threshold.
    ExtremeSkewness,
    /// The excess kurtosis exceeds the detector's threshold.
    ExtremeKurtosis,
    /// More than half of the elements are zero.
    DeadNeurons,
    /// The value range, or the magnitude of the minimum or maximum, is unusually large.
    ExtremeValues,
    /// Not emitted by `TensorAnalyzer::detect_tensor_anomalies` in this file.
    /// NOTE(review): presumably reserved for saturated activations; confirm with callers.
    Saturation,
    /// Not emitted by `TensorAnalyzer::detect_tensor_anomalies` in this file.
    /// NOTE(review): presumably reserved for outlier detection; confirm with callers.
    Outliers,
}
104
/// Seriousness attached to a [`TensorAnomaly`].
///
/// `TensorAnalyzer::detect_tensor_anomalies` reports NaN values as `Critical`, infinite
/// values and a high zero ratio as `High`, and every other finding as `Medium`; it never
/// emits `Low`.
#[derive(Debug, Serialize, Deserialize)]
pub enum AnomalySeverity {
    /// Lowest level.
    Low,
    /// Used for skewness, kurtosis and extreme-value findings.
    Medium,
    /// Used for infinite values and a high zero ratio.
    High,
    /// Highest level; used for NaN values.
    Critical,
}
113
/// Stateless namespace for the tensor analysis routines; all of its functions are
/// associated functions that take no `self`.
pub struct TensorAnalyzer;
116
117impl TensorAnalyzer {
118 pub fn analyze_tensors_batch(tensors: &[ArrayD<f32>]) -> Result<BatchTensorAnalysis> {
120 let mut results = Vec::new();
121 let mut overall_stats = TensorStatistics::default();
122
123 for (i, tensor) in tensors.iter().enumerate() {
124 let stats = Self::compute_tensor_statistics(tensor)?;
125 let anomalies = Self::detect_tensor_anomalies(&stats);
126
127 results.push(TensorAnalysisResult {
128 tensor_index: i,
129 shape: tensor.shape().to_vec(),
130 statistics: stats.clone(),
131 anomalies,
132 });
133
134 overall_stats.accumulate(&stats);
135 }
136
137 overall_stats.finalize(tensors.len());
138
139 Ok(BatchTensorAnalysis {
140 individual_results: results,
141 overall_statistics: overall_stats,
142 batch_size: tensors.len(),
143 analysis_timestamp: chrono::Utc::now(),
144 })
145 }
146
147 pub fn compute_tensor_statistics(tensor: &ArrayD<f32>) -> Result<TensorStatistics> {
149 let data: Vec<f32> = tensor.iter().cloned().collect();
150 let count = data.len();
151
152 if count == 0 {
153 return Ok(TensorStatistics::default());
154 }
155
156 let sum: f32 = data.iter().sum();
158 let mean = sum / count as f32;
159
160 let variance = data.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / count as f32;
161 let std_dev = variance.sqrt();
162
163 let min = data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
165 let max = data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
166
167 let mut sorted_data = data.clone();
169 sorted_data.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
170
171 let median = Self::percentile(&sorted_data, 50.0);
172 let p25 = Self::percentile(&sorted_data, 25.0);
173 let p75 = Self::percentile(&sorted_data, 75.0);
174
175 let nan_count = data.iter().filter(|&&x| x.is_nan()).count();
177 let inf_count = data.iter().filter(|&&x| x.is_infinite()).count();
178 let zero_count = data.iter().filter(|&&x| x == 0.0).count();
179
180 let skewness = Self::compute_skewness(&data, mean, std_dev);
182 let kurtosis = Self::compute_kurtosis(&data, mean, std_dev);
183
184 Ok(TensorStatistics {
185 count,
186 mean,
187 std_dev,
188 min,
189 max,
190 median,
191 p25,
192 p75,
193 nan_count,
194 inf_count,
195 zero_count,
196 skewness,
197 kurtosis,
198 })
199 }
200
201 pub fn detect_tensor_anomalies(stats: &TensorStatistics) -> Vec<TensorAnomaly> {
203 let mut anomalies = Vec::new();
204
205 if stats.nan_count > 0 {
207 anomalies.push(TensorAnomaly {
208 anomaly_type: AnomalyType::NanValues,
209 severity: AnomalySeverity::Critical,
210 description: format!("Found {} NaN values in tensor", stats.nan_count),
211 suggested_fix: "Check for division by zero or invalid operations".to_string(),
212 });
213 }
214
215 if stats.inf_count > 0 {
217 anomalies.push(TensorAnomaly {
218 anomaly_type: AnomalyType::InfiniteValues,
219 severity: AnomalySeverity::High,
220 description: format!("Found {} infinite values in tensor", stats.inf_count),
221 suggested_fix: "Check for overflow or division by zero".to_string(),
222 });
223 }
224
225 if stats.skewness.abs() > 3.0 {
227 anomalies.push(TensorAnomaly {
228 anomaly_type: AnomalyType::ExtremeSkewness,
229 severity: AnomalySeverity::Medium,
230 description: format!("Extreme skewness detected: {:.2}", stats.skewness),
231 suggested_fix: "Consider data normalization or outlier removal".to_string(),
232 });
233 }
234
235 if stats.kurtosis > 10.0 {
237 anomalies.push(TensorAnomaly {
238 anomaly_type: AnomalyType::ExtremeKurtosis,
239 severity: AnomalySeverity::Medium,
240 description: format!("High kurtosis detected: {:.2}", stats.kurtosis),
241 suggested_fix: "Check for outliers or distribution issues".to_string(),
242 });
243 }
244
245 let zero_ratio = stats.zero_count as f32 / stats.count as f32;
247 if zero_ratio > 0.5 {
248 anomalies.push(TensorAnomaly {
249 anomaly_type: AnomalyType::DeadNeurons,
250 severity: AnomalySeverity::High,
251 description: format!("High zero ratio: {:.2}%", zero_ratio * 100.0),
252 suggested_fix:
253 "Check learning rate, weight initialization, or activation functions"
254 .to_string(),
255 });
256 }
257
258 let range = stats.max - stats.min;
260 if range > 1000.0 || stats.max.abs() > 100.0 || stats.min.abs() > 100.0 {
261 anomalies.push(TensorAnomaly {
262 anomaly_type: AnomalyType::ExtremeValues,
263 severity: AnomalySeverity::Medium,
264 description: format!("Extreme value range: [{:.2}, {:.2}]", stats.min, stats.max),
265 suggested_fix: "Consider gradient clipping or weight regularization".to_string(),
266 });
267 }
268
269 anomalies
270 }
271
272 fn percentile(sorted_data: &[f32], percentile: f32) -> f32 {
274 if sorted_data.is_empty() {
275 return 0.0;
276 }
277
278 let index = (percentile / 100.0) * (sorted_data.len() - 1) as f32;
279 let lower_index = index.floor() as usize;
280 let upper_index = (index.ceil() as usize).min(sorted_data.len() - 1);
281
282 if lower_index == upper_index {
283 sorted_data[lower_index]
284 } else {
285 let weight = index - lower_index as f32;
286 sorted_data[lower_index] * (1.0 - weight) + sorted_data[upper_index] * weight
287 }
288 }
289
290 fn compute_skewness(data: &[f32], mean: f32, std_dev: f32) -> f32 {
292 if std_dev == 0.0 || data.len() < 3 {
293 return 0.0;
294 }
295
296 let n = data.len() as f32;
297 let skewness = data.iter().map(|&x| ((x - mean) / std_dev).powi(3)).sum::<f32>() / n;
298
299 skewness
300 }
301
302 fn compute_kurtosis(data: &[f32], mean: f32, std_dev: f32) -> f32 {
304 if std_dev == 0.0 || data.len() < 4 {
305 return 0.0;
306 }
307
308 let n = data.len() as f32;
309 let kurtosis = data.iter().map(|&x| ((x - mean) / std_dev).powi(4)).sum::<f32>() / n;
310
311 kurtosis - 3.0 }
313
314 pub fn compare_tensors(
316 baseline: &ArrayD<f32>,
317 current: &ArrayD<f32>,
318 ) -> Result<TensorComparisonResult> {
319 let baseline_stats = Self::compute_tensor_statistics(baseline)?;
320 let current_stats = Self::compute_tensor_statistics(current)?;
321
322 let mean_drift = (current_stats.mean - baseline_stats.mean).abs();
324 let std_drift = (current_stats.std_dev - baseline_stats.std_dev).abs();
325 let distribution_shift = Self::compute_distribution_shift(&baseline_stats, ¤t_stats);
326
327 let drift_severity = if mean_drift > 1.0 || std_drift > 1.0 || distribution_shift > 0.5 {
328 TensorDriftSeverity::High
329 } else if mean_drift > 0.5 || std_drift > 0.5 || distribution_shift > 0.3 {
330 TensorDriftSeverity::Medium
331 } else {
332 TensorDriftSeverity::Low
333 };
334
335 Ok(TensorComparisonResult {
336 baseline_stats,
337 current_stats,
338 mean_drift,
339 std_drift,
340 distribution_shift,
341 drift_severity: drift_severity.clone(),
342 recommendations: Self::generate_drift_recommendations(
343 drift_severity,
344 mean_drift,
345 std_drift,
346 ),
347 })
348 }
349
350 fn compute_distribution_shift(baseline: &TensorStatistics, current: &TensorStatistics) -> f32 {
352 let mean_diff = ((current.mean - baseline.mean) / (baseline.std_dev + 1e-8)).abs();
354 let std_diff = ((current.std_dev - baseline.std_dev) / (baseline.std_dev + 1e-8)).abs();
355 let skew_diff = (current.skewness - baseline.skewness).abs();
356
357 (mean_diff + std_diff + skew_diff * 0.5) / 2.5
358 }
359
360 fn generate_drift_recommendations(
362 severity: TensorDriftSeverity,
363 mean_drift: f32,
364 std_drift: f32,
365 ) -> Vec<String> {
366 let mut recommendations = Vec::new();
367
368 match severity {
369 TensorDriftSeverity::High => {
370 recommendations.push("Significant tensor drift detected".to_string());
371 if mean_drift > 1.0 {
372 recommendations.push("Consider retraining or data rebalancing".to_string());
373 }
374 if std_drift > 1.0 {
375 recommendations.push("Check for changes in data preprocessing".to_string());
376 }
377 },
378 TensorDriftSeverity::Medium => {
379 recommendations.push("Moderate tensor drift detected".to_string());
380 recommendations.push("Monitor closely for further changes".to_string());
381 },
382 TensorDriftSeverity::Low => {
383 recommendations.push("Minimal tensor drift - within acceptable range".to_string());
384 },
385 }
386
387 recommendations
388 }
389}
390
/// Outcome of comparing a current tensor against a baseline tensor.
///
/// Produced by `TensorAnalyzer::compare_tensors`.
#[derive(Debug, Serialize, Deserialize)]
pub struct TensorComparisonResult {
    /// Statistics of the baseline tensor.
    pub baseline_stats: TensorStatistics,
    /// Statistics of the current tensor.
    pub current_stats: TensorStatistics,
    /// Absolute difference between the two means.
    pub mean_drift: f32,
    /// Absolute difference between the two standard deviations.
    pub std_drift: f32,
    /// Combined score built from the mean, standard-deviation and skewness differences.
    pub distribution_shift: f32,
    /// Classification derived from the three measures above.
    pub drift_severity: TensorDriftSeverity,
    /// Human-readable advice matching `drift_severity`.
    pub recommendations: Vec<String>,
}
402
/// Classification of how far a tensor drifted from its baseline.
///
/// The thresholds below are the ones applied by `TensorAnalyzer::compare_tensors`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TensorDriftSeverity {
    /// Neither the `Medium` nor the `High` thresholds were exceeded.
    Low,
    /// Mean or standard-deviation drift above 0.5, or distribution shift above 0.3.
    Medium,
    /// Mean or standard-deviation drift above 1.0, or distribution shift above 0.5.
    High,
}