1use anyhow::Result;
4use scirs2_core::ndarray::*; use scirs2_core::random::random; use serde::{Deserialize, Serialize};
7use std::collections::{HashMap, VecDeque};
8use std::fmt;
9use std::time::Instant;
10use uuid::Uuid;
11
12use crate::DebugConfig;
13
/// Summary statistics computed from a single tensor snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorStats {
    /// Tensor dimensions.
    pub shape: Vec<usize>,
    /// Element type name (always "f64" here — values are widened before analysis).
    pub dtype: String,
    pub total_elements: usize,
    pub mean: f64,
    pub std: f64,
    pub min: f64,
    pub max: f64,
    /// Median over the non-NaN values; NaN when none exist.
    pub median: f64,
    pub l1_norm: f64,
    pub l2_norm: f64,
    pub infinity_norm: f64,
    pub nan_count: usize,
    pub inf_count: usize,
    pub zero_count: usize,
    /// Estimated as element count × element byte size.
    pub memory_usage_bytes: usize,
    /// Fraction of elements that are exactly zero.
    pub sparsity: f64,
}
34
/// Histogram/distribution summary of a tensor's (finite) values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorDistribution {
    /// (bin start value, count) pairs.
    pub histogram: Vec<(f64, usize)>,
    /// Keyed "p5", "p25", "p50", "p75", "p95", "p99".
    pub percentiles: HashMap<String, f64>,
    /// Values beyond 1.5×IQR of [Q1, Q3], capped at 100 entries.
    pub outliers: Vec<f64>,
    pub skewness: f64,
    /// Excess kurtosis (0 for a Gaussian).
    pub kurtosis: f64,
}
44
/// Metadata and statistics snapshot recorded for one inspected tensor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorInfo {
    /// Unique id assigned at inspection time.
    pub id: Uuid,
    pub name: String,
    /// Layer this tensor belongs to, if known.
    pub layer_name: Option<String>,
    /// Operation that produced the tensor, if known.
    pub operation: Option<String>,
    pub timestamp: chrono::DateTime<chrono::Utc>,
    pub stats: TensorStats,
    /// Present only when distribution sampling selected this tensor.
    pub distribution: Option<TensorDistribution>,
    /// Filled in later by `inspect_gradients`.
    pub gradient_stats: Option<TensorStats>,
}
57
/// Comparison between two tracked tensors.
///
/// NOTE(review): the numeric metrics are approximations derived from summary
/// statistics only — element-wise values are not retained by the inspector.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorComparison {
    pub tensor1_id: Uuid,
    pub tensor2_id: Uuid,
    pub mse: f64,
    pub mae: f64,
    pub max_diff: f64,
    pub cosine_similarity: f64,
    pub correlation: f64,
    pub shape_match: bool,
    pub dtype_match: bool,
}
71
/// Bounded history of per-tensor statistics collected while monitoring is on.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorTimeSeries {
    pub tensor_id: Uuid,
    /// Parallel to `values`: one timestamp per recorded sample.
    pub timestamps: VecDeque<chrono::DateTime<chrono::Utc>>,
    pub values: VecDeque<TensorStats>,
    /// Oldest samples are evicted once this cap is exceeded.
    pub max_history: usize,
}
80
/// Directed edge in the tensor dependency graph: `source` feeds `target`
/// through `operation`, with an associated strength `weight`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorDependency {
    pub source_id: Uuid,
    pub target_id: Uuid,
    pub operation: String,
    pub weight: f64,
}
89
/// Events recorded over a tensor's lifetime.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TensorLifecycleEvent {
    /// Tensor first seen; carries its in-memory size.
    Created { size_bytes: usize },
    /// Tensor changed by the named operation.
    Modified { operation: String },
    /// Tensor read; `access_type` is a free-form label.
    Accessed { access_type: String },
    /// Tensor dropped.
    Destroyed,
}
98
/// Timestamped event log for one tensor, plus derived access counters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorLifecycle {
    pub tensor_id: Uuid,
    pub events: Vec<(chrono::DateTime<chrono::Utc>, TensorLifecycleEvent)>,
    /// Count of `Accessed` events only.
    pub total_accesses: usize,
    pub creation_time: chrono::DateTime<chrono::Utc>,
}
107
/// Bundle of the optional "advanced" analyses run over a tensor's values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdvancedTensorAnalysis {
    /// `None` unless the tensor is a 2-D matrix with at least 4 elements.
    pub spectral_analysis: Option<SpectralAnalysis>,
    pub information_content: InformationContent,
    pub stability_metrics: StabilityMetrics,
    pub relationship_analysis: RelationshipAnalysis,
}
116
/// Approximate spectral summary of a 2-D tensor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpectralAnalysis {
    /// Currently only the dominant eigenvalue (from power iteration).
    pub eigenvalues: Vec<f64>,
    /// Proxy only: dominant eigenvalue over the Frobenius norm.
    pub condition_number: f64,
    /// Non-zero-count rank estimate, capped at min(rows, cols).
    pub rank: usize,
    pub spectral_norm: f64,
}
125
/// Entropy-based information metrics over quantized tensor values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InformationContent {
    /// Shannon entropy (bits) of a 100-bucket quantization.
    pub entropy: f64,
    /// Currently always 0.0 — not computed yet.
    pub mutual_information: f64,
    /// 2^entropy; 1.0 for zero entropy.
    pub effective_rank: f64,
    /// Unique-value ratio used as a compressibility proxy.
    pub compression_ratio: f64,
}
134
/// Heuristic numerical-stability scores derived from value dispersion.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StabilityMetrics {
    /// 1 / (1 + coefficient of variation); 1.0 for constant data.
    pub numerical_stability: f64,
    /// Currently a fixed placeholder (0.8).
    pub gradient_stability: f64,
    /// Std relative to the largest magnitude.
    pub perturbation_sensitivity: f64,
    pub robustness_score: f64,
}
143
/// Coarse cross-tensor relationship estimates, keyed by tensor id.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RelationshipAnalysis {
    pub cross_correlations: HashMap<Uuid, f64>,
    /// Absolute value of the corresponding cross-correlation.
    pub dependency_strength: HashMap<Uuid, f64>,
    /// Currently always empty — causal inference is not implemented.
    pub causal_relationships: Vec<TensorDependency>,
}
151
/// Tracks tensors, their statistics, alerts, histories and relationships
/// for debugging numerical issues.
#[derive(Debug)]
pub struct TensorInspector {
    config: DebugConfig,
    /// Inspected tensors, bounded by `config.max_tracked_tensors`.
    tracked_tensors: HashMap<Uuid, TensorInfo>,
    comparisons: Vec<TensorComparison>,
    alerts: Vec<TensorAlert>,
    /// Per-tensor stats history, populated only while monitoring is enabled.
    time_series: HashMap<Uuid, TensorTimeSeries>,
    dependencies: Vec<TensorDependency>,
    lifecycles: HashMap<Uuid, TensorLifecycle>,
    monitoring_enabled: bool,
    last_analysis_time: Option<Instant>,
}
166
167impl TensorInspector {
168 pub fn new(config: &DebugConfig) -> Self {
170 Self {
171 config: config.clone(),
172 tracked_tensors: HashMap::new(),
173 comparisons: Vec::new(),
174 alerts: Vec::new(),
175 time_series: HashMap::new(),
177 dependencies: Vec::new(),
178 lifecycles: HashMap::new(),
179 monitoring_enabled: false,
180 last_analysis_time: None,
181 }
182 }
183
    /// Starts the inspector. Currently only emits a log line; tracking state
    /// is initialized in `new`.
    pub async fn start(&mut self) -> Result<()> {
        tracing::info!("Starting tensor inspector");
        Ok(())
    }
189
190 pub fn inspect_tensor<T>(
192 &mut self,
193 tensor: &ArrayD<T>,
194 name: &str,
195 layer_name: Option<&str>,
196 operation: Option<&str>,
197 ) -> Result<Uuid>
198 where
199 T: Clone + Into<f64> + fmt::Debug + 'static,
200 {
201 let id = Uuid::new_v4();
202
203 let values: Vec<f64> = tensor.iter().map(|x| x.clone().into()).collect();
205 let shape = tensor.shape().to_vec();
206
207 let stats = self.compute_tensor_stats(&values, &shape, std::mem::size_of::<T>())?;
208 let distribution = if self.should_compute_distribution() {
209 Some(self.compute_distribution(&values)?)
210 } else {
211 None
212 };
213
214 let tensor_info = TensorInfo {
215 id,
216 name: name.to_string(),
217 layer_name: layer_name.map(|s| s.to_string()),
218 operation: operation.map(|s| s.to_string()),
219 timestamp: chrono::Utc::now(),
220 stats,
221 distribution,
222 gradient_stats: None,
223 };
224
225 self.check_tensor_alerts(&tensor_info)?;
227
228 if self.tracked_tensors.len() < self.config.max_tracked_tensors {
230 self.tracked_tensors.insert(id, tensor_info.clone());
231 }
232
233 self.record_lifecycle_event(
235 id,
236 TensorLifecycleEvent::Created {
237 size_bytes: std::mem::size_of::<T>() * tensor.len(),
238 },
239 );
240
241 if self.monitoring_enabled {
242 if let Some(tensor_info) = self.tracked_tensors.get(&id) {
243 self.update_time_series(id, tensor_info.stats.clone());
244 }
245 }
246
247 Ok(id)
248 }
249
250 pub fn inspect_gradients<T>(&mut self, tensor_id: Uuid, gradients: &ArrayD<T>) -> Result<()>
252 where
253 T: Clone + Into<f64> + fmt::Debug + 'static,
254 {
255 let values: Vec<f64> = gradients.iter().map(|x| x.clone().into()).collect();
256 let shape = gradients.shape().to_vec();
257
258 let gradient_stats =
259 self.compute_tensor_stats(&values, &shape, std::mem::size_of::<T>())?;
260
261 if let Some(tensor_info) = self.tracked_tensors.get_mut(&tensor_id) {
262 tensor_info.gradient_stats = Some(gradient_stats);
263 }
264
265 let tensor_info_for_alerts = self
267 .tracked_tensors
268 .get(&tensor_id)
269 .map(|info| (info.id, info.name.clone(), info.gradient_stats.clone()));
270
271 if let Some((id, name, grad_stats)) = tensor_info_for_alerts {
272 self.check_gradient_alerts_with_data(id, &name, grad_stats)?;
273 }
274
275 Ok(())
276 }
277
278 pub fn compare_tensors(&mut self, id1: Uuid, id2: Uuid) -> Result<TensorComparison> {
280 let tensor1 = self
281 .tracked_tensors
282 .get(&id1)
283 .ok_or_else(|| anyhow::anyhow!("Tensor {} not found", id1))?;
284 let tensor2 = self
285 .tracked_tensors
286 .get(&id2)
287 .ok_or_else(|| anyhow::anyhow!("Tensor {} not found", id2))?;
288
289 let comparison = TensorComparison {
290 tensor1_id: id1,
291 tensor2_id: id2,
292 mse: self.compute_mse(&tensor1.stats, &tensor2.stats),
293 mae: self.compute_mae(&tensor1.stats, &tensor2.stats),
294 max_diff: (tensor1.stats.max - tensor2.stats.max).abs(),
295 cosine_similarity: self.compute_cosine_similarity(&tensor1.stats, &tensor2.stats),
296 correlation: self.compute_correlation(&tensor1.stats, &tensor2.stats),
297 shape_match: tensor1.stats.shape == tensor2.stats.shape,
298 dtype_match: tensor1.stats.dtype == tensor2.stats.dtype,
299 };
300
301 self.comparisons.push(comparison.clone());
302 Ok(comparison)
303 }
304
    /// Returns the tracked record for `id`, if it was retained.
    pub fn get_tensor_info(&self, id: Uuid) -> Option<&TensorInfo> {
        self.tracked_tensors.get(&id)
    }
309
    /// Returns every tracked tensor record (unordered — map iteration order).
    pub fn get_all_tensors(&self) -> Vec<&TensorInfo> {
        self.tracked_tensors.values().collect()
    }
314
315 pub fn get_tensors_by_layer(&self, layer_name: &str) -> Vec<&TensorInfo> {
317 self.tracked_tensors
318 .values()
319 .filter(|info| info.layer_name.as_ref() == Some(&layer_name.to_string()))
320 .collect()
321 }
322
    /// Returns all alerts accumulated so far, in emission order.
    pub fn get_alerts(&self) -> &[TensorAlert] {
        &self.alerts
    }
327
328 pub fn clear(&mut self) {
330 self.tracked_tensors.clear();
331 self.comparisons.clear();
332 self.alerts.clear();
333 self.time_series.clear();
335 self.dependencies.clear();
336 self.lifecycles.clear();
337 self.last_analysis_time = None;
338 }
339
340 pub async fn generate_report(&self) -> Result<TensorInspectionReport> {
342 let total_tensors = self.tracked_tensors.len();
343 let tensors_with_issues = self
344 .tracked_tensors
345 .values()
346 .filter(|info| info.stats.nan_count > 0 || info.stats.inf_count > 0)
347 .count();
348
349 let memory_usage =
350 self.tracked_tensors.values().map(|info| info.stats.memory_usage_bytes).sum();
351
352 Ok(TensorInspectionReport {
353 total_tensors,
354 tensors_with_issues,
355 total_memory_usage: memory_usage,
356 alerts: self.alerts.clone(),
357 comparisons: self.comparisons.clone(),
358 summary_stats: self.compute_summary_stats(),
359 })
360 }
361
362 pub fn enable_monitoring(&mut self, enable: bool) {
366 self.monitoring_enabled = enable;
367 if enable {
368 tracing::info!("Real-time tensor monitoring enabled");
369 } else {
370 tracing::info!("Real-time tensor monitoring disabled");
371 }
372 }
373
374 pub fn track_dependency(
376 &mut self,
377 source_id: Uuid,
378 target_id: Uuid,
379 operation: &str,
380 weight: f64,
381 ) {
382 let dependency = TensorDependency {
383 source_id,
384 target_id,
385 operation: operation.to_string(),
386 weight,
387 };
388 self.dependencies.push(dependency);
389 }
390
391 pub fn record_lifecycle_event(&mut self, tensor_id: Uuid, event: TensorLifecycleEvent) {
393 let lifecycle = self.lifecycles.entry(tensor_id).or_insert_with(|| TensorLifecycle {
394 tensor_id,
395 events: Vec::new(),
396 total_accesses: 0,
397 creation_time: chrono::Utc::now(),
398 });
399
400 lifecycle.events.push((chrono::Utc::now(), event.clone()));
401
402 if matches!(event, TensorLifecycleEvent::Accessed { .. }) {
403 lifecycle.total_accesses += 1;
404 }
405 }
406
407 pub fn update_time_series(&mut self, tensor_id: Uuid, stats: TensorStats) {
409 if !self.monitoring_enabled {
410 return;
411 }
412
413 let time_series = self.time_series.entry(tensor_id).or_insert_with(|| TensorTimeSeries {
414 tensor_id,
415 timestamps: VecDeque::new(),
416 values: VecDeque::new(),
417 max_history: 1000, });
419
420 time_series.timestamps.push_back(chrono::Utc::now());
421 time_series.values.push_back(stats);
422
423 while time_series.timestamps.len() > time_series.max_history {
425 time_series.timestamps.pop_front();
426 time_series.values.pop_front();
427 }
428 }
429
430 pub fn perform_advanced_analysis<T>(&self, tensor: &ArrayD<T>) -> Result<AdvancedTensorAnalysis>
432 where
433 T: Clone + Into<f64> + fmt::Debug + 'static,
434 {
435 let values: Vec<f64> = tensor.iter().map(|x| x.clone().into()).collect();
436
437 Ok(AdvancedTensorAnalysis {
438 spectral_analysis: self.compute_spectral_analysis(&values, tensor.shape())?,
439 information_content: self.compute_information_content(&values)?,
440 stability_metrics: self.compute_stability_metrics(&values)?,
441 relationship_analysis: self.compute_relationship_analysis(&values)?,
442 })
443 }
444
445 pub fn detect_advanced_anomalies(&self, tensor_id: Uuid) -> Result<Vec<TensorAlert>> {
447 let mut alerts = Vec::new();
448
449 if let Some(time_series) = self.time_series.get(&tensor_id) {
450 if time_series.values.len() >= 10 {
452 let recent_mean =
453 time_series.values.iter().rev().take(5).map(|stats| stats.mean).sum::<f64>()
454 / 5.0;
455 let historical_mean =
456 time_series.values.iter().take(5).map(|stats| stats.mean).sum::<f64>() / 5.0;
457
458 let drift_ratio =
459 (recent_mean - historical_mean).abs() / historical_mean.abs().max(1e-8);
460
461 if drift_ratio > 0.5 {
462 if let Some(tensor_info) = self.tracked_tensors.get(&tensor_id) {
463 alerts.push(TensorAlert {
464 id: Uuid::new_v4(),
465 tensor_id,
466 tensor_name: tensor_info.name.clone(),
467 alert_type: TensorAlertType::ExtremeValues,
468 severity: AlertSeverity::Warning,
469 message: format!(
470 "Detected statistical drift in tensor '{}': {:.2}% change",
471 tensor_info.name,
472 drift_ratio * 100.0
473 ),
474 timestamp: chrono::Utc::now(),
475 });
476 }
477 }
478 }
479 }
480
481 Ok(alerts)
482 }
483
    /// Returns every recorded dependency edge, in insertion order.
    pub fn get_dependencies(&self) -> &[TensorDependency] {
        &self.dependencies
    }
488
    /// Returns the lifecycle log for `tensor_id`, if one exists.
    pub fn get_lifecycle(&self, tensor_id: Uuid) -> Option<&TensorLifecycle> {
        self.lifecycles.get(&tensor_id)
    }
493
    /// Returns the monitoring time series for `tensor_id`, if one exists.
    pub fn get_time_series(&self, tensor_id: Uuid) -> Option<&TensorTimeSeries> {
        self.time_series.get(&tensor_id)
    }
498
499 pub fn analyze_tensor_relationships(&self) -> HashMap<Uuid, Vec<Uuid>> {
501 let mut relationships = HashMap::new();
502
503 for dependency in &self.dependencies {
504 relationships
505 .entry(dependency.source_id)
506 .or_insert_with(Vec::new)
507 .push(dependency.target_id);
508 }
509
510 relationships
511 }
512
513 pub fn get_frequent_tensors(&self, min_accesses: usize) -> Vec<Uuid> {
515 self.lifecycles
516 .iter()
517 .filter(|(_, lifecycle)| lifecycle.total_accesses >= min_accesses)
518 .map(|(id, _)| *id)
519 .collect()
520 }
521
522 fn compute_tensor_stats(
525 &self,
526 values: &[f64],
527 shape: &[usize],
528 element_size: usize,
529 ) -> Result<TensorStats> {
530 let total_elements = values.len();
531 let mean = values.iter().sum::<f64>() / total_elements as f64;
532
533 let variance =
534 values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / total_elements as f64;
535 let std = variance.sqrt();
536
537 let min = values.iter().cloned().fold(f64::INFINITY, f64::min);
538 let max = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
539
540 let mut sorted_values = values.to_vec();
541 sorted_values.retain(|x| !x.is_nan());
543 sorted_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
544 let median = if sorted_values.is_empty() {
545 f64::NAN
546 } else if sorted_values.len() % 2 == 0 {
547 (sorted_values[sorted_values.len() / 2 - 1] + sorted_values[sorted_values.len() / 2])
548 / 2.0
549 } else {
550 sorted_values[sorted_values.len() / 2]
551 };
552
553 let l1_norm = values.iter().map(|x| x.abs()).sum::<f64>();
554 let l2_norm = values.iter().map(|x| x * x).sum::<f64>().sqrt();
555 let infinity_norm = values.iter().map(|x| x.abs()).fold(0.0, f64::max);
556
557 let nan_count = values.iter().filter(|x| x.is_nan()).count();
558 let inf_count = values.iter().filter(|x| x.is_infinite()).count();
559 let zero_count = values.iter().filter(|x| **x == 0.0).count();
560
561 let memory_usage_bytes = total_elements * element_size;
562 let sparsity = zero_count as f64 / total_elements as f64;
563
564 Ok(TensorStats {
565 shape: shape.to_vec(),
566 dtype: "f64".to_string(), total_elements,
568 mean,
569 std,
570 min,
571 max,
572 median,
573 l1_norm,
574 l2_norm,
575 infinity_norm,
576 nan_count,
577 inf_count,
578 zero_count,
579 memory_usage_bytes,
580 sparsity,
581 })
582 }
583
584 fn compute_distribution(&self, values: &[f64]) -> Result<TensorDistribution> {
585 let num_bins = 50;
587 let min = values.iter().cloned().fold(f64::INFINITY, f64::min);
588 let max = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
589 let bin_width = (max - min) / num_bins as f64;
590
591 let mut histogram = vec![(0.0, 0); num_bins];
592 for &value in values {
593 if !value.is_finite() {
594 continue;
595 }
596 let bin_idx = ((value - min) / bin_width).floor() as usize;
597 let bin_idx = bin_idx.min(num_bins - 1);
598 histogram[bin_idx].0 = min + bin_idx as f64 * bin_width;
599 histogram[bin_idx].1 += 1;
600 }
601
602 let mut sorted_values =
604 values.iter().cloned().filter(|x| x.is_finite()).collect::<Vec<_>>();
605 sorted_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
606
607 let mut percentiles = HashMap::new();
608 for &p in &[5.0, 25.0, 50.0, 75.0, 95.0, 99.0] {
609 let idx = ((p / 100.0) * (sorted_values.len() - 1) as f64) as usize;
610 percentiles.insert(format!("p{}", p as u8), sorted_values[idx]);
611 }
612
613 let q1 = percentiles["p25"];
615 let q3 = percentiles["p75"];
616 let iqr = q3 - q1;
617 let lower_bound = q1 - 1.5 * iqr;
618 let upper_bound = q3 + 1.5 * iqr;
619
620 let outliers: Vec<f64> = sorted_values
621 .iter()
622 .cloned()
623 .filter(|&x| x < lower_bound || x > upper_bound)
624 .take(100) .collect();
626
627 let mean = sorted_values.iter().sum::<f64>() / sorted_values.len() as f64;
629 let variance = sorted_values.iter().map(|x| (x - mean).powi(2)).sum::<f64>()
630 / sorted_values.len() as f64;
631 let std = variance.sqrt();
632
633 let skewness = if std > 0.0 {
634 sorted_values.iter().map(|x| ((x - mean) / std).powi(3)).sum::<f64>()
635 / sorted_values.len() as f64
636 } else {
637 0.0
638 };
639
640 let kurtosis = if std > 0.0 {
641 sorted_values.iter().map(|x| ((x - mean) / std).powi(4)).sum::<f64>()
642 / sorted_values.len() as f64
643 - 3.0
644 } else {
645 0.0
646 };
647
648 Ok(TensorDistribution {
649 histogram,
650 percentiles,
651 outliers,
652 skewness,
653 kurtosis,
654 })
655 }
656
657 fn should_compute_distribution(&self) -> bool {
658 self.config.sampling_rate >= 1.0
659 || (self.config.sampling_rate > 0.0 && random::<f32>() < self.config.sampling_rate)
660 }
661
662 fn check_tensor_alerts(&mut self, tensor_info: &TensorInfo) -> Result<()> {
663 if tensor_info.stats.nan_count > 0 {
665 self.alerts.push(TensorAlert {
666 id: Uuid::new_v4(),
667 tensor_id: tensor_info.id,
668 tensor_name: tensor_info.name.clone(),
669 alert_type: TensorAlertType::NaNValues,
670 severity: AlertSeverity::Critical,
671 message: format!(
672 "Found {} NaN values in tensor '{}'",
673 tensor_info.stats.nan_count, tensor_info.name
674 ),
675 timestamp: chrono::Utc::now(),
676 });
677 }
678
679 if tensor_info.stats.inf_count > 0 {
681 self.alerts.push(TensorAlert {
682 id: Uuid::new_v4(),
683 tensor_id: tensor_info.id,
684 tensor_name: tensor_info.name.clone(),
685 alert_type: TensorAlertType::InfiniteValues,
686 severity: AlertSeverity::Critical,
687 message: format!(
688 "Found {} infinite values in tensor '{}'",
689 tensor_info.stats.inf_count, tensor_info.name
690 ),
691 timestamp: chrono::Utc::now(),
692 });
693 }
694
695 if tensor_info.stats.max.abs() > 1e10 || tensor_info.stats.min.abs() > 1e10 {
697 self.alerts.push(TensorAlert {
698 id: Uuid::new_v4(),
699 tensor_id: tensor_info.id,
700 tensor_name: tensor_info.name.clone(),
701 alert_type: TensorAlertType::ExtremeValues,
702 severity: AlertSeverity::Warning,
703 message: format!(
704 "Extreme values detected in tensor '{}': min={:.2e}, max={:.2e}",
705 tensor_info.name, tensor_info.stats.min, tensor_info.stats.max
706 ),
707 timestamp: chrono::Utc::now(),
708 });
709 }
710
711 Ok(())
712 }
713
714 fn check_gradient_alerts_with_data(
715 &mut self,
716 tensor_id: Uuid,
717 tensor_name: &str,
718 grad_stats: Option<TensorStats>,
719 ) -> Result<()> {
720 if let Some(ref stats) = grad_stats {
721 if stats.l2_norm < 1e-8 {
723 self.alerts.push(TensorAlert {
724 id: Uuid::new_v4(),
725 tensor_id,
726 tensor_name: tensor_name.to_string(),
727 alert_type: TensorAlertType::VanishingGradients,
728 severity: AlertSeverity::Warning,
729 message: format!(
730 "Vanishing gradients detected in '{}': L2 norm = {:.2e}",
731 tensor_name, stats.l2_norm
732 ),
733 timestamp: chrono::Utc::now(),
734 });
735 }
736
737 if stats.l2_norm > 100.0 {
739 self.alerts.push(TensorAlert {
740 id: Uuid::new_v4(),
741 tensor_id,
742 tensor_name: tensor_name.to_string(),
743 alert_type: TensorAlertType::ExplodingGradients,
744 severity: AlertSeverity::Critical,
745 message: format!(
746 "Exploding gradients detected in '{}': L2 norm = {:.2e}",
747 tensor_name, stats.l2_norm
748 ),
749 timestamp: chrono::Utc::now(),
750 });
751 }
752 }
753
754 Ok(())
755 }
756
757 fn compute_mse(&self, stats1: &TensorStats, stats2: &TensorStats) -> f64 {
758 (stats1.mean - stats2.mean).powi(2)
760 }
761
762 fn compute_mae(&self, stats1: &TensorStats, stats2: &TensorStats) -> f64 {
763 (stats1.mean - stats2.mean).abs()
765 }
766
767 fn compute_cosine_similarity(&self, stats1: &TensorStats, stats2: &TensorStats) -> f64 {
768 if stats1.l2_norm == 0.0 || stats2.l2_norm == 0.0 {
770 0.0
771 } else {
772 (stats1.mean * stats2.mean) / (stats1.l2_norm * stats2.l2_norm)
773 }
774 }
775
776 fn compute_correlation(&self, stats1: &TensorStats, stats2: &TensorStats) -> f64 {
777 if stats1.std == 0.0 || stats2.std == 0.0 {
779 0.0
780 } else {
781 0.5 }
783 }
784
785 fn compute_spectral_analysis(
788 &self,
789 values: &[f64],
790 shape: &[usize],
791 ) -> Result<Option<SpectralAnalysis>> {
792 if shape.len() != 2 || values.len() < 4 {
794 return Ok(None);
795 }
796
797 let rows = shape[0];
798 let cols = shape[1];
799
800 let largest_eigenvalue = self.power_iteration(values, rows, cols)?;
802
803 let frobenius_norm = (values.iter().map(|x| x * x).sum::<f64>()).sqrt();
805 let condition_number = if frobenius_norm > 1e-12 {
806 largest_eigenvalue / frobenius_norm.max(1e-12)
807 } else {
808 f64::INFINITY
809 };
810
811 let rank = values.iter().filter(|&&x| x.abs() > 1e-10).count().min(rows.min(cols));
813
814 Ok(Some(SpectralAnalysis {
815 eigenvalues: vec![largest_eigenvalue], condition_number,
817 rank,
818 spectral_norm: largest_eigenvalue,
819 }))
820 }
821
822 fn power_iteration(&self, matrix: &[f64], rows: usize, cols: usize) -> Result<f64> {
823 if rows != cols {
824 return Ok(0.0); }
826
827 let n = rows;
828 let mut x = vec![1.0; n];
829 let mut lambda = 0.0;
830
831 for _ in 0..10 {
833 let mut ax = vec![0.0; n];
834
835 for i in 0..n {
837 for j in 0..n {
838 ax[i] += matrix[i * n + j] * x[j];
839 }
840 }
841
842 let dot_ax_x: f64 = ax.iter().zip(x.iter()).map(|(a, b)| a * b).sum();
844 let dot_x_x: f64 = x.iter().map(|a| a * a).sum();
845
846 if dot_x_x > 1e-12 {
847 lambda = dot_ax_x / dot_x_x;
848 }
849
850 let norm = (ax.iter().map(|a| a * a).sum::<f64>()).sqrt();
852 if norm > 1e-12 {
853 x = ax.iter().map(|a| a / norm).collect();
854 }
855 }
856
857 Ok(lambda.abs())
858 }
859
860 fn compute_information_content(&self, values: &[f64]) -> Result<InformationContent> {
861 let mut histogram = std::collections::HashMap::new();
863 let quantization_levels = 100;
864 let min_val = values.iter().cloned().fold(f64::INFINITY, f64::min);
865 let max_val = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
866 let range = max_val - min_val;
867
868 if range > 1e-12 {
869 for &value in values {
870 let bucket = ((value - min_val) / range * quantization_levels as f64) as usize;
871 let bucket = bucket.min(quantization_levels - 1);
872 *histogram.entry(bucket).or_insert(0) += 1;
873 }
874 }
875
876 let total_count = values.len() as f64;
877 let entropy = if total_count > 0.0 {
878 histogram
879 .values()
880 .map(|&count| {
881 let p = count as f64 / total_count;
882 if p > 0.0 {
883 -p * p.log2()
884 } else {
885 0.0
886 }
887 })
888 .sum()
889 } else {
890 0.0
891 };
892
893 let effective_rank = if entropy > 0.0 { 2.0_f64.powf(entropy) } else { 1.0 };
895
896 let mut sorted_values = values.to_vec();
898 sorted_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
899 sorted_values.dedup_by(|a, b| (*a - *b).abs() < 1e-10);
900 let unique_values = sorted_values.len();
901 let compression_ratio = unique_values as f64 / values.len() as f64;
902
903 Ok(InformationContent {
904 entropy,
905 mutual_information: 0.0, effective_rank,
907 compression_ratio,
908 })
909 }
910
911 fn compute_stability_metrics(&self, values: &[f64]) -> Result<StabilityMetrics> {
912 let mean = values.iter().sum::<f64>() / values.len() as f64;
914 let variance = values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / values.len() as f64;
915 let std_dev = variance.sqrt();
916
917 let numerical_stability = if std_dev > 1e-12 {
918 1.0 / (1.0 + std_dev / mean.abs().max(1e-12))
919 } else {
920 1.0
921 };
922
923 let max_abs = values.iter().map(|x| x.abs()).fold(0.0, f64::max);
925 let perturbation_sensitivity = if max_abs > 1e-12 { std_dev / max_abs } else { 0.0 };
926
927 let robustness_score = numerical_stability * (1.0 - perturbation_sensitivity.min(1.0));
929
930 Ok(StabilityMetrics {
931 numerical_stability,
932 gradient_stability: 0.8, perturbation_sensitivity,
934 robustness_score,
935 })
936 }
937
938 fn compute_relationship_analysis(&self, values: &[f64]) -> Result<RelationshipAnalysis> {
939 let mut cross_correlations = HashMap::new();
941 let mut dependency_strength = HashMap::new();
942 let causal_relationships = Vec::new(); for tensor_info in self.tracked_tensors.values() {
946 if tensor_info.stats.total_elements > 0 {
947 let correlation =
948 self.compute_simple_correlation(values, &[tensor_info.stats.mean]);
949 cross_correlations.insert(tensor_info.id, correlation);
950 dependency_strength.insert(tensor_info.id, correlation.abs());
951 }
952 }
953
954 Ok(RelationshipAnalysis {
955 cross_correlations,
956 dependency_strength,
957 causal_relationships,
958 })
959 }
960
961 fn compute_simple_correlation(&self, values1: &[f64], values2: &[f64]) -> f64 {
962 if values1.is_empty() || values2.is_empty() {
963 return 0.0;
964 }
965
966 let mean1 = values1.iter().sum::<f64>() / values1.len() as f64;
967 let mean2 = values2.iter().sum::<f64>() / values2.len() as f64;
968
969 let min_len = values1.len().min(values2.len());
970 let numerator: f64 = values1
971 .iter()
972 .zip(values2.iter())
973 .take(min_len)
974 .map(|(x, y)| (x - mean1) * (y - mean2))
975 .sum();
976
977 let sum_sq1: f64 = values1.iter().take(min_len).map(|x| (x - mean1).powi(2)).sum();
978 let sum_sq2: f64 = values2.iter().take(min_len).map(|y| (y - mean2).powi(2)).sum();
979
980 let denominator = (sum_sq1 * sum_sq2).sqrt();
981 if denominator > 1e-12 {
982 numerator / denominator
983 } else {
984 0.0
985 }
986 }
987
988 fn compute_summary_stats(&self) -> HashMap<String, f64> {
989 let mut stats = HashMap::new();
990
991 if !self.tracked_tensors.is_empty() {
992 let values: Vec<f64> = self.tracked_tensors.values().map(|t| t.stats.mean).collect();
993 stats.insert(
994 "mean_of_means".to_string(),
995 values.iter().sum::<f64>() / values.len() as f64,
996 );
997
998 let total_memory: usize =
999 self.tracked_tensors.values().map(|t| t.stats.memory_usage_bytes).sum();
1000 stats.insert(
1001 "total_memory_mb".to_string(),
1002 total_memory as f64 / (1024.0 * 1024.0),
1003 );
1004
1005 let avg_sparsity: f64 =
1006 self.tracked_tensors.values().map(|t| t.stats.sparsity).sum::<f64>()
1007 / self.tracked_tensors.len() as f64;
1008 stats.insert("avg_sparsity".to_string(), avg_sparsity);
1009
1010 stats.insert(
1012 "total_dependencies".to_string(),
1013 self.dependencies.len() as f64,
1014 );
1015 stats.insert(
1016 "monitored_tensors".to_string(),
1017 self.time_series.len() as f64,
1018 );
1019 stats.insert(
1020 "active_lifecycles".to_string(),
1021 self.lifecycles.len() as f64,
1022 );
1023 }
1024
1025 stats
1026 }
1027}
1028
/// Categories of tensor health problems the inspector can flag.
/// NOTE(review): statistical drift is also reported as `ExtremeValues`
/// because no dedicated variant exists yet.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TensorAlertType {
    NaNValues,
    InfiniteValues,
    ExtremeValues,
    VanishingGradients,
    ExplodingGradients,
    MemoryUsage,
    ShapeMismatch,
}
1040
/// Severity level attached to each alert, in ascending order of urgency.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AlertSeverity {
    Info,
    Warning,
    Critical,
}
1048
/// A single alert raised against a tensor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorAlert {
    /// Unique alert id (independent of the tensor id).
    pub id: Uuid,
    pub tensor_id: Uuid,
    /// Copied at alert time so the alert survives tensor eviction.
    pub tensor_name: String,
    pub alert_type: TensorAlertType,
    pub severity: AlertSeverity,
    /// Human-readable description of the finding.
    pub message: String,
    pub timestamp: chrono::DateTime<chrono::Utc>,
}
1060
/// Snapshot report produced by `TensorInspector::generate_report`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorInspectionReport {
    pub total_tensors: usize,
    /// Tensors whose stats show at least one NaN or infinite value.
    pub tensors_with_issues: usize,
    /// Sum of per-tensor `memory_usage_bytes` estimates.
    pub total_memory_usage: usize,
    pub alerts: Vec<TensorAlert>,
    pub comparisons: Vec<TensorComparison>,
    pub summary_stats: HashMap<String, f64>,
}
1071
impl TensorInspectionReport {
    /// True when any recorded alert reports NaN values.
    pub fn has_nan_values(&self) -> bool {
        self.alerts.iter().any(|a| matches!(a.alert_type, TensorAlertType::NaNValues))
    }

    /// True when any recorded alert reports infinite values.
    pub fn has_inf_values(&self) -> bool {
        self.alerts
            .iter()
            .any(|a| matches!(a.alert_type, TensorAlertType::InfiniteValues))
    }

    /// Number of NaN ALERTS (one per offending tensor), not the number of
    /// NaN elements — per-element counts live in `TensorStats::nan_count`.
    pub fn total_nan_count(&self) -> usize {
        self.alerts
            .iter()
            .filter(|a| matches!(a.alert_type, TensorAlertType::NaNValues))
            .count()
    }

    /// Number of infinite-value alerts (same caveat as `total_nan_count`).
    pub fn total_inf_count(&self) -> usize {
        self.alerts
            .iter()
            .filter(|a| matches!(a.alert_type, TensorAlertType::InfiniteValues))
            .count()
    }

    /// True when any alert carries `Critical` severity.
    pub fn has_critical_alerts(&self) -> bool {
        self.alerts.iter().any(|a| matches!(a.severity, AlertSeverity::Critical))
    }
}