1use anyhow::Result;
4use scirs2_core::ndarray::*; use scirs2_core::random::random; use serde::{Deserialize, Serialize};
7use std::collections::{HashMap, VecDeque};
8use std::fmt;
9use std::time::Instant;
10use uuid::Uuid;
11
12use crate::DebugConfig;
13
/// Summary statistics computed over a tensor's elements (widened to `f64`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorStats {
    /// Tensor shape, one extent per axis.
    pub shape: Vec<usize>,
    /// Element-type label; the analysis pipeline currently always records
    /// "f64" because values are converted before processing.
    pub dtype: String,
    /// Number of elements in the flattened tensor.
    pub total_elements: usize,
    /// Arithmetic mean of all elements.
    pub mean: f64,
    /// Population standard deviation.
    pub std: f64,
    /// Smallest element value.
    pub min: f64,
    /// Largest element value.
    pub max: f64,
    /// Median of the sorted elements.
    pub median: f64,
    /// Sum of absolute values.
    pub l1_norm: f64,
    /// Euclidean (L2) norm.
    pub l2_norm: f64,
    /// Largest absolute value.
    pub infinity_norm: f64,
    /// Number of NaN elements.
    pub nan_count: usize,
    /// Number of +/- infinite elements.
    pub inf_count: usize,
    /// Number of elements exactly equal to zero.
    pub zero_count: usize,
    /// Approximate footprint: element count times per-element size.
    pub memory_usage_bytes: usize,
    /// Fraction of elements that are exactly zero.
    pub sparsity: f64,
}
34
/// Empirical distribution of a tensor's finite values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorDistribution {
    /// Fixed-width histogram as (bin left edge, count) pairs.
    pub histogram: Vec<(f64, usize)>,
    /// Percentile values keyed "p5", "p25", "p50", "p75", "p95", "p99".
    pub percentiles: HashMap<String, f64>,
    /// Values outside the 1.5*IQR whiskers (capped in number).
    pub outliers: Vec<f64>,
    /// Third standardized moment (asymmetry).
    pub skewness: f64,
    /// Excess kurtosis (fourth standardized moment minus 3).
    pub kurtosis: f64,
}
44
/// Everything recorded about one inspected tensor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorInfo {
    /// Unique id assigned at inspection time.
    pub id: Uuid,
    /// Caller-supplied tensor name.
    pub name: String,
    /// Name of the layer that produced the tensor, if known.
    pub layer_name: Option<String>,
    /// Operation that produced the tensor, if known.
    pub operation: Option<String>,
    /// UTC time of inspection.
    pub timestamp: chrono::DateTime<chrono::Utc>,
    /// Value statistics.
    pub stats: TensorStats,
    /// Value distribution; present only when distribution sampling fired.
    pub distribution: Option<TensorDistribution>,
    /// Gradient statistics, filled in later by `inspect_gradients`.
    pub gradient_stats: Option<TensorStats>,
}
57
/// Result of comparing two tracked tensors.
///
/// NOTE: raw element data is not retained after inspection, so the numeric
/// metrics here are coarse proxies derived from each tensor's summary
/// statistics, not element-wise computations.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorComparison {
    /// First tensor in the comparison.
    pub tensor1_id: Uuid,
    /// Second tensor in the comparison.
    pub tensor2_id: Uuid,
    /// Squared difference of the two means (summary-level MSE proxy).
    pub mse: f64,
    /// Absolute difference of the two means (summary-level MAE proxy).
    pub mae: f64,
    /// Absolute difference of the two maxima.
    pub max_diff: f64,
    /// Mean-product over norm-product approximation of cosine similarity.
    pub cosine_similarity: f64,
    /// Correlation estimate (currently a fixed placeholder when defined).
    pub correlation: f64,
    /// Whether the recorded shapes match exactly.
    pub shape_match: bool,
    /// Whether the recorded dtypes match.
    pub dtype_match: bool,
}
71
/// Bounded history of statistics for one tensor, used for drift detection.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorTimeSeries {
    /// Tensor this series belongs to.
    pub tensor_id: Uuid,
    /// Sample timestamps, parallel to `values`, oldest first.
    pub timestamps: VecDeque<chrono::DateTime<chrono::Utc>>,
    /// Statistics snapshots, oldest first.
    pub values: VecDeque<TensorStats>,
    /// Maximum samples retained; older entries are evicted from the front.
    pub max_history: usize,
}
80
/// Directed dependency edge between two tensors.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorDependency {
    /// Tensor the data flows from.
    pub source_id: Uuid,
    /// Tensor the data flows to.
    pub target_id: Uuid,
    /// Operation label connecting the two tensors.
    pub operation: String,
    /// Caller-supplied strength of the dependency.
    pub weight: f64,
}
89
/// A single event in a tensor's lifecycle.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TensorLifecycleEvent {
    /// Tensor was created with the given payload size.
    Created { size_bytes: usize },
    /// Tensor was modified by the named operation.
    Modified { operation: String },
    /// Tensor was read; `access_type` describes the kind of access.
    Accessed { access_type: String },
    /// Tensor was dropped.
    Destroyed,
}
98
/// Chronological event log for a single tensor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorLifecycle {
    /// Tensor this log belongs to.
    pub tensor_id: Uuid,
    /// Timestamped events, oldest first.
    pub events: Vec<(chrono::DateTime<chrono::Utc>, TensorLifecycleEvent)>,
    /// Number of `Accessed` events recorded.
    pub total_accesses: usize,
    /// UTC time the log was opened.
    pub creation_time: chrono::DateTime<chrono::Utc>,
}
107
/// Bundle of deeper analyses performed on a tensor's values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdvancedTensorAnalysis {
    /// Eigenvalue/rank analysis; only available for 2-D tensors.
    pub spectral_analysis: Option<SpectralAnalysis>,
    /// Entropy and compressibility measures.
    pub information_content: InformationContent,
    /// Numerical-stability heuristics.
    pub stability_metrics: StabilityMetrics,
    /// Correlations against the other tracked tensors.
    pub relationship_analysis: RelationshipAnalysis,
}
116
/// Spectral properties estimated for a 2-D (matrix) tensor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpectralAnalysis {
    /// Estimated eigenvalues; currently only the dominant one, obtained via
    /// power iteration.
    pub eigenvalues: Vec<f64>,
    /// Heuristic condition-number estimate (dominant eigenvalue over
    /// Frobenius norm, not a true sigma_max/sigma_min ratio).
    pub condition_number: f64,
    /// Heuristic rank estimate (count of non-negligible entries, capped by
    /// the smaller matrix dimension).
    pub rank: usize,
    /// Estimated spectral norm (magnitude of the dominant eigenvalue).
    pub spectral_norm: f64,
}
125
/// Information-theoretic measures of a tensor's values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InformationContent {
    /// Shannon entropy (bits) of a quantized value histogram.
    pub entropy: f64,
    /// Mutual information; currently a 0.0 placeholder.
    pub mutual_information: f64,
    /// 2^entropy — the effective number of distinct value states.
    pub effective_rank: f64,
    /// Unique-value count divided by total element count.
    pub compression_ratio: f64,
}
134
/// Heuristic numerical-stability scores for a tensor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StabilityMetrics {
    /// 1 / (1 + coefficient of variation); closer to 1 is more stable.
    pub numerical_stability: f64,
    /// Gradient stability; currently a fixed placeholder value.
    pub gradient_stability: f64,
    /// Standard deviation relative to the largest absolute value.
    pub perturbation_sensitivity: f64,
    /// Combined stability/sensitivity score.
    pub robustness_score: f64,
}
143
/// Relationships between one tensor and the other tracked tensors.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RelationshipAnalysis {
    /// Correlation estimate per tracked tensor id.
    pub cross_correlations: HashMap<Uuid, f64>,
    /// Absolute correlation per tracked tensor id.
    pub dependency_strength: HashMap<Uuid, f64>,
    /// Inferred causal edges; not populated by the current analysis.
    pub causal_relationships: Vec<TensorDependency>,
}
151
/// Tracks tensors observed during a debug session: their statistics,
/// pairwise comparisons, alerts, time series, dependencies, and lifecycles.
#[derive(Debug)]
pub struct TensorInspector {
    /// Debug configuration (tracking budget, distribution sampling rate).
    config: DebugConfig,
    /// Tensors retained for inspection, keyed by their generated id.
    tracked_tensors: HashMap<Uuid, TensorInfo>,
    /// History of pairwise tensor comparisons.
    comparisons: Vec<TensorComparison>,
    /// Alerts raised so far (NaN/inf/extreme values, gradient issues, drift).
    alerts: Vec<TensorAlert>,
    /// Per-tensor bounded statistics history; populated only while
    /// monitoring is enabled.
    time_series: HashMap<Uuid, TensorTimeSeries>,
    /// Directed tensor-to-tensor dependency edges.
    dependencies: Vec<TensorDependency>,
    /// Per-tensor lifecycle event logs.
    lifecycles: HashMap<Uuid, TensorLifecycle>,
    /// Whether real-time time-series collection is active.
    monitoring_enabled: bool,
    /// Timestamp of the last analysis pass; currently only ever cleared,
    /// never set — reserved for future scheduling logic.
    last_analysis_time: Option<Instant>,
}
166
167impl TensorInspector {
168 pub fn new(config: &DebugConfig) -> Self {
170 Self {
171 config: config.clone(),
172 tracked_tensors: HashMap::new(),
173 comparisons: Vec::new(),
174 alerts: Vec::new(),
175 time_series: HashMap::new(),
177 dependencies: Vec::new(),
178 lifecycles: HashMap::new(),
179 monitoring_enabled: false,
180 last_analysis_time: None,
181 }
182 }
183
184 pub async fn start(&mut self) -> Result<()> {
186 tracing::info!("Starting tensor inspector");
187 Ok(())
188 }
189
190 pub fn inspect_tensor<T>(
192 &mut self,
193 tensor: &ArrayD<T>,
194 name: &str,
195 layer_name: Option<&str>,
196 operation: Option<&str>,
197 ) -> Result<Uuid>
198 where
199 T: Clone + Into<f64> + fmt::Debug + 'static,
200 {
201 let id = Uuid::new_v4();
202
203 let values: Vec<f64> = tensor.iter().map(|x| x.clone().into()).collect();
205 let shape = tensor.shape().to_vec();
206
207 let stats = self.compute_tensor_stats(&values, &shape, std::mem::size_of::<T>())?;
208 let distribution = if self.should_compute_distribution() {
209 Some(self.compute_distribution(&values)?)
210 } else {
211 None
212 };
213
214 let tensor_info = TensorInfo {
215 id,
216 name: name.to_string(),
217 layer_name: layer_name.map(|s| s.to_string()),
218 operation: operation.map(|s| s.to_string()),
219 timestamp: chrono::Utc::now(),
220 stats,
221 distribution,
222 gradient_stats: None,
223 };
224
225 self.check_tensor_alerts(&tensor_info)?;
227
228 if self.tracked_tensors.len() < self.config.max_tracked_tensors {
230 self.tracked_tensors.insert(id, tensor_info.clone());
231 }
232
233 self.record_lifecycle_event(
235 id,
236 TensorLifecycleEvent::Created {
237 size_bytes: std::mem::size_of::<T>() * tensor.len(),
238 },
239 );
240
241 if self.monitoring_enabled {
242 if let Some(tensor_info) = self.tracked_tensors.get(&id) {
243 self.update_time_series(id, tensor_info.stats.clone());
244 }
245 }
246
247 Ok(id)
248 }
249
250 pub fn inspect_gradients<T>(&mut self, tensor_id: Uuid, gradients: &ArrayD<T>) -> Result<()>
252 where
253 T: Clone + Into<f64> + fmt::Debug + 'static,
254 {
255 let values: Vec<f64> = gradients.iter().map(|x| x.clone().into()).collect();
256 let shape = gradients.shape().to_vec();
257
258 let gradient_stats =
259 self.compute_tensor_stats(&values, &shape, std::mem::size_of::<T>())?;
260
261 if let Some(tensor_info) = self.tracked_tensors.get_mut(&tensor_id) {
262 tensor_info.gradient_stats = Some(gradient_stats);
263 }
264
265 let tensor_info_for_alerts = self
267 .tracked_tensors
268 .get(&tensor_id)
269 .map(|info| (info.id, info.name.clone(), info.gradient_stats.clone()));
270
271 if let Some((id, name, grad_stats)) = tensor_info_for_alerts {
272 self.check_gradient_alerts_with_data(id, &name, grad_stats)?;
273 }
274
275 Ok(())
276 }
277
278 pub fn compare_tensors(&mut self, id1: Uuid, id2: Uuid) -> Result<TensorComparison> {
280 let tensor1 = self
281 .tracked_tensors
282 .get(&id1)
283 .ok_or_else(|| anyhow::anyhow!("Tensor {} not found", id1))?;
284 let tensor2 = self
285 .tracked_tensors
286 .get(&id2)
287 .ok_or_else(|| anyhow::anyhow!("Tensor {} not found", id2))?;
288
289 let comparison = TensorComparison {
290 tensor1_id: id1,
291 tensor2_id: id2,
292 mse: self.compute_mse(&tensor1.stats, &tensor2.stats),
293 mae: self.compute_mae(&tensor1.stats, &tensor2.stats),
294 max_diff: (tensor1.stats.max - tensor2.stats.max).abs(),
295 cosine_similarity: self.compute_cosine_similarity(&tensor1.stats, &tensor2.stats),
296 correlation: self.compute_correlation(&tensor1.stats, &tensor2.stats),
297 shape_match: tensor1.stats.shape == tensor2.stats.shape,
298 dtype_match: tensor1.stats.dtype == tensor2.stats.dtype,
299 };
300
301 self.comparisons.push(comparison.clone());
302 Ok(comparison)
303 }
304
305 pub fn get_tensor_info(&self, id: Uuid) -> Option<&TensorInfo> {
307 self.tracked_tensors.get(&id)
308 }
309
310 pub fn get_all_tensors(&self) -> Vec<&TensorInfo> {
312 self.tracked_tensors.values().collect()
313 }
314
315 pub fn get_tensors_by_layer(&self, layer_name: &str) -> Vec<&TensorInfo> {
317 self.tracked_tensors
318 .values()
319 .filter(|info| info.layer_name.as_ref() == Some(&layer_name.to_string()))
320 .collect()
321 }
322
323 pub fn get_alerts(&self) -> &[TensorAlert] {
325 &self.alerts
326 }
327
328 pub fn clear(&mut self) {
330 self.tracked_tensors.clear();
331 self.comparisons.clear();
332 self.alerts.clear();
333 self.time_series.clear();
335 self.dependencies.clear();
336 self.lifecycles.clear();
337 self.last_analysis_time = None;
338 }
339
340 pub async fn generate_report(&self) -> Result<TensorInspectionReport> {
342 let total_tensors = self.tracked_tensors.len();
343 let tensors_with_issues = self
344 .tracked_tensors
345 .values()
346 .filter(|info| info.stats.nan_count > 0 || info.stats.inf_count > 0)
347 .count();
348
349 let memory_usage =
350 self.tracked_tensors.values().map(|info| info.stats.memory_usage_bytes).sum();
351
352 Ok(TensorInspectionReport {
353 total_tensors,
354 tensors_with_issues,
355 total_memory_usage: memory_usage,
356 alerts: self.alerts.clone(),
357 comparisons: self.comparisons.clone(),
358 summary_stats: self.compute_summary_stats(),
359 })
360 }
361
362 pub fn enable_monitoring(&mut self, enable: bool) {
366 self.monitoring_enabled = enable;
367 if enable {
368 tracing::info!("Real-time tensor monitoring enabled");
369 } else {
370 tracing::info!("Real-time tensor monitoring disabled");
371 }
372 }
373
374 pub fn track_dependency(
376 &mut self,
377 source_id: Uuid,
378 target_id: Uuid,
379 operation: &str,
380 weight: f64,
381 ) {
382 let dependency = TensorDependency {
383 source_id,
384 target_id,
385 operation: operation.to_string(),
386 weight,
387 };
388 self.dependencies.push(dependency);
389 }
390
391 pub fn record_lifecycle_event(&mut self, tensor_id: Uuid, event: TensorLifecycleEvent) {
393 let lifecycle = self.lifecycles.entry(tensor_id).or_insert_with(|| TensorLifecycle {
394 tensor_id,
395 events: Vec::new(),
396 total_accesses: 0,
397 creation_time: chrono::Utc::now(),
398 });
399
400 lifecycle.events.push((chrono::Utc::now(), event.clone()));
401
402 if matches!(event, TensorLifecycleEvent::Accessed { .. }) {
403 lifecycle.total_accesses += 1;
404 }
405 }
406
407 pub fn update_time_series(&mut self, tensor_id: Uuid, stats: TensorStats) {
409 if !self.monitoring_enabled {
410 return;
411 }
412
413 let time_series = self.time_series.entry(tensor_id).or_insert_with(|| TensorTimeSeries {
414 tensor_id,
415 timestamps: VecDeque::new(),
416 values: VecDeque::new(),
417 max_history: 1000, });
419
420 time_series.timestamps.push_back(chrono::Utc::now());
421 time_series.values.push_back(stats);
422
423 while time_series.timestamps.len() > time_series.max_history {
425 time_series.timestamps.pop_front();
426 time_series.values.pop_front();
427 }
428 }
429
430 pub fn perform_advanced_analysis<T>(&self, tensor: &ArrayD<T>) -> Result<AdvancedTensorAnalysis>
432 where
433 T: Clone + Into<f64> + fmt::Debug + 'static,
434 {
435 let values: Vec<f64> = tensor.iter().map(|x| x.clone().into()).collect();
436
437 Ok(AdvancedTensorAnalysis {
438 spectral_analysis: self.compute_spectral_analysis(&values, tensor.shape())?,
439 information_content: self.compute_information_content(&values)?,
440 stability_metrics: self.compute_stability_metrics(&values)?,
441 relationship_analysis: self.compute_relationship_analysis(&values)?,
442 })
443 }
444
445 pub fn detect_advanced_anomalies(&self, tensor_id: Uuid) -> Result<Vec<TensorAlert>> {
447 let mut alerts = Vec::new();
448
449 if let Some(time_series) = self.time_series.get(&tensor_id) {
450 if time_series.values.len() >= 10 {
452 let recent_mean =
453 time_series.values.iter().rev().take(5).map(|stats| stats.mean).sum::<f64>()
454 / 5.0;
455 let historical_mean =
456 time_series.values.iter().take(5).map(|stats| stats.mean).sum::<f64>() / 5.0;
457
458 let drift_ratio =
459 (recent_mean - historical_mean).abs() / historical_mean.abs().max(1e-8);
460
461 if drift_ratio > 0.5 {
462 if let Some(tensor_info) = self.tracked_tensors.get(&tensor_id) {
463 alerts.push(TensorAlert {
464 id: Uuid::new_v4(),
465 tensor_id,
466 tensor_name: tensor_info.name.clone(),
467 alert_type: TensorAlertType::ExtremeValues,
468 severity: AlertSeverity::Warning,
469 message: format!(
470 "Detected statistical drift in tensor '{}': {:.2}% change",
471 tensor_info.name,
472 drift_ratio * 100.0
473 ),
474 timestamp: chrono::Utc::now(),
475 });
476 }
477 }
478 }
479 }
480
481 Ok(alerts)
482 }
483
484 pub fn get_dependencies(&self) -> &[TensorDependency] {
486 &self.dependencies
487 }
488
489 pub fn get_lifecycle(&self, tensor_id: Uuid) -> Option<&TensorLifecycle> {
491 self.lifecycles.get(&tensor_id)
492 }
493
494 pub fn get_time_series(&self, tensor_id: Uuid) -> Option<&TensorTimeSeries> {
496 self.time_series.get(&tensor_id)
497 }
498
499 pub fn analyze_tensor_relationships(&self) -> HashMap<Uuid, Vec<Uuid>> {
501 let mut relationships = HashMap::new();
502
503 for dependency in &self.dependencies {
504 relationships
505 .entry(dependency.source_id)
506 .or_insert_with(Vec::new)
507 .push(dependency.target_id);
508 }
509
510 relationships
511 }
512
513 pub fn get_frequent_tensors(&self, min_accesses: usize) -> Vec<Uuid> {
515 self.lifecycles
516 .iter()
517 .filter(|(_, lifecycle)| lifecycle.total_accesses >= min_accesses)
518 .map(|(id, _)| *id)
519 .collect()
520 }
521
522 fn compute_tensor_stats(
525 &self,
526 values: &[f64],
527 shape: &[usize],
528 element_size: usize,
529 ) -> Result<TensorStats> {
530 let total_elements = values.len();
531 let mean = values.iter().sum::<f64>() / total_elements as f64;
532
533 let variance =
534 values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / total_elements as f64;
535 let std = variance.sqrt();
536
537 let min = values.iter().cloned().fold(f64::INFINITY, f64::min);
538 let max = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
539
540 let mut sorted_values = values.to_vec();
541 sorted_values.sort_by(|a, b| a.partial_cmp(b).unwrap());
542 let median = if sorted_values.len() % 2 == 0 {
543 (sorted_values[sorted_values.len() / 2 - 1] + sorted_values[sorted_values.len() / 2])
544 / 2.0
545 } else {
546 sorted_values[sorted_values.len() / 2]
547 };
548
549 let l1_norm = values.iter().map(|x| x.abs()).sum::<f64>();
550 let l2_norm = values.iter().map(|x| x * x).sum::<f64>().sqrt();
551 let infinity_norm = values.iter().map(|x| x.abs()).fold(0.0, f64::max);
552
553 let nan_count = values.iter().filter(|x| x.is_nan()).count();
554 let inf_count = values.iter().filter(|x| x.is_infinite()).count();
555 let zero_count = values.iter().filter(|x| **x == 0.0).count();
556
557 let memory_usage_bytes = total_elements * element_size;
558 let sparsity = zero_count as f64 / total_elements as f64;
559
560 Ok(TensorStats {
561 shape: shape.to_vec(),
562 dtype: "f64".to_string(), total_elements,
564 mean,
565 std,
566 min,
567 max,
568 median,
569 l1_norm,
570 l2_norm,
571 infinity_norm,
572 nan_count,
573 inf_count,
574 zero_count,
575 memory_usage_bytes,
576 sparsity,
577 })
578 }
579
580 fn compute_distribution(&self, values: &[f64]) -> Result<TensorDistribution> {
581 let num_bins = 50;
583 let min = values.iter().cloned().fold(f64::INFINITY, f64::min);
584 let max = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
585 let bin_width = (max - min) / num_bins as f64;
586
587 let mut histogram = vec![(0.0, 0); num_bins];
588 for &value in values {
589 if !value.is_finite() {
590 continue;
591 }
592 let bin_idx = ((value - min) / bin_width).floor() as usize;
593 let bin_idx = bin_idx.min(num_bins - 1);
594 histogram[bin_idx].0 = min + bin_idx as f64 * bin_width;
595 histogram[bin_idx].1 += 1;
596 }
597
598 let mut sorted_values =
600 values.iter().cloned().filter(|x| x.is_finite()).collect::<Vec<_>>();
601 sorted_values.sort_by(|a, b| a.partial_cmp(b).unwrap());
602
603 let mut percentiles = HashMap::new();
604 for &p in &[5.0, 25.0, 50.0, 75.0, 95.0, 99.0] {
605 let idx = ((p / 100.0) * (sorted_values.len() - 1) as f64) as usize;
606 percentiles.insert(format!("p{}", p as u8), sorted_values[idx]);
607 }
608
609 let q1 = percentiles["p25"];
611 let q3 = percentiles["p75"];
612 let iqr = q3 - q1;
613 let lower_bound = q1 - 1.5 * iqr;
614 let upper_bound = q3 + 1.5 * iqr;
615
616 let outliers: Vec<f64> = sorted_values
617 .iter()
618 .cloned()
619 .filter(|&x| x < lower_bound || x > upper_bound)
620 .take(100) .collect();
622
623 let mean = sorted_values.iter().sum::<f64>() / sorted_values.len() as f64;
625 let variance = sorted_values.iter().map(|x| (x - mean).powi(2)).sum::<f64>()
626 / sorted_values.len() as f64;
627 let std = variance.sqrt();
628
629 let skewness = if std > 0.0 {
630 sorted_values.iter().map(|x| ((x - mean) / std).powi(3)).sum::<f64>()
631 / sorted_values.len() as f64
632 } else {
633 0.0
634 };
635
636 let kurtosis = if std > 0.0 {
637 sorted_values.iter().map(|x| ((x - mean) / std).powi(4)).sum::<f64>()
638 / sorted_values.len() as f64
639 - 3.0
640 } else {
641 0.0
642 };
643
644 Ok(TensorDistribution {
645 histogram,
646 percentiles,
647 outliers,
648 skewness,
649 kurtosis,
650 })
651 }
652
653 fn should_compute_distribution(&self) -> bool {
654 self.config.sampling_rate >= 1.0
655 || (self.config.sampling_rate > 0.0 && random::<f32>() < self.config.sampling_rate)
656 }
657
658 fn check_tensor_alerts(&mut self, tensor_info: &TensorInfo) -> Result<()> {
659 if tensor_info.stats.nan_count > 0 {
661 self.alerts.push(TensorAlert {
662 id: Uuid::new_v4(),
663 tensor_id: tensor_info.id,
664 tensor_name: tensor_info.name.clone(),
665 alert_type: TensorAlertType::NaNValues,
666 severity: AlertSeverity::Critical,
667 message: format!(
668 "Found {} NaN values in tensor '{}'",
669 tensor_info.stats.nan_count, tensor_info.name
670 ),
671 timestamp: chrono::Utc::now(),
672 });
673 }
674
675 if tensor_info.stats.inf_count > 0 {
677 self.alerts.push(TensorAlert {
678 id: Uuid::new_v4(),
679 tensor_id: tensor_info.id,
680 tensor_name: tensor_info.name.clone(),
681 alert_type: TensorAlertType::InfiniteValues,
682 severity: AlertSeverity::Critical,
683 message: format!(
684 "Found {} infinite values in tensor '{}'",
685 tensor_info.stats.inf_count, tensor_info.name
686 ),
687 timestamp: chrono::Utc::now(),
688 });
689 }
690
691 if tensor_info.stats.max.abs() > 1e10 || tensor_info.stats.min.abs() > 1e10 {
693 self.alerts.push(TensorAlert {
694 id: Uuid::new_v4(),
695 tensor_id: tensor_info.id,
696 tensor_name: tensor_info.name.clone(),
697 alert_type: TensorAlertType::ExtremeValues,
698 severity: AlertSeverity::Warning,
699 message: format!(
700 "Extreme values detected in tensor '{}': min={:.2e}, max={:.2e}",
701 tensor_info.name, tensor_info.stats.min, tensor_info.stats.max
702 ),
703 timestamp: chrono::Utc::now(),
704 });
705 }
706
707 Ok(())
708 }
709
710 fn check_gradient_alerts_with_data(
711 &mut self,
712 tensor_id: Uuid,
713 tensor_name: &str,
714 grad_stats: Option<TensorStats>,
715 ) -> Result<()> {
716 if let Some(ref stats) = grad_stats {
717 if stats.l2_norm < 1e-8 {
719 self.alerts.push(TensorAlert {
720 id: Uuid::new_v4(),
721 tensor_id,
722 tensor_name: tensor_name.to_string(),
723 alert_type: TensorAlertType::VanishingGradients,
724 severity: AlertSeverity::Warning,
725 message: format!(
726 "Vanishing gradients detected in '{}': L2 norm = {:.2e}",
727 tensor_name, stats.l2_norm
728 ),
729 timestamp: chrono::Utc::now(),
730 });
731 }
732
733 if stats.l2_norm > 100.0 {
735 self.alerts.push(TensorAlert {
736 id: Uuid::new_v4(),
737 tensor_id,
738 tensor_name: tensor_name.to_string(),
739 alert_type: TensorAlertType::ExplodingGradients,
740 severity: AlertSeverity::Critical,
741 message: format!(
742 "Exploding gradients detected in '{}': L2 norm = {:.2e}",
743 tensor_name, stats.l2_norm
744 ),
745 timestamp: chrono::Utc::now(),
746 });
747 }
748 }
749
750 Ok(())
751 }
752
753 fn compute_mse(&self, stats1: &TensorStats, stats2: &TensorStats) -> f64 {
754 (stats1.mean - stats2.mean).powi(2)
756 }
757
758 fn compute_mae(&self, stats1: &TensorStats, stats2: &TensorStats) -> f64 {
759 (stats1.mean - stats2.mean).abs()
761 }
762
763 fn compute_cosine_similarity(&self, stats1: &TensorStats, stats2: &TensorStats) -> f64 {
764 if stats1.l2_norm == 0.0 || stats2.l2_norm == 0.0 {
766 0.0
767 } else {
768 (stats1.mean * stats2.mean) / (stats1.l2_norm * stats2.l2_norm)
769 }
770 }
771
772 fn compute_correlation(&self, stats1: &TensorStats, stats2: &TensorStats) -> f64 {
773 if stats1.std == 0.0 || stats2.std == 0.0 {
775 0.0
776 } else {
777 0.5 }
779 }
780
781 fn compute_spectral_analysis(
784 &self,
785 values: &[f64],
786 shape: &[usize],
787 ) -> Result<Option<SpectralAnalysis>> {
788 if shape.len() != 2 || values.len() < 4 {
790 return Ok(None);
791 }
792
793 let rows = shape[0];
794 let cols = shape[1];
795
796 let largest_eigenvalue = self.power_iteration(values, rows, cols)?;
798
799 let frobenius_norm = (values.iter().map(|x| x * x).sum::<f64>()).sqrt();
801 let condition_number = if frobenius_norm > 1e-12 {
802 largest_eigenvalue / frobenius_norm.max(1e-12)
803 } else {
804 f64::INFINITY
805 };
806
807 let rank = values.iter().filter(|&&x| x.abs() > 1e-10).count().min(rows.min(cols));
809
810 Ok(Some(SpectralAnalysis {
811 eigenvalues: vec![largest_eigenvalue], condition_number,
813 rank,
814 spectral_norm: largest_eigenvalue,
815 }))
816 }
817
818 fn power_iteration(&self, matrix: &[f64], rows: usize, cols: usize) -> Result<f64> {
819 if rows != cols {
820 return Ok(0.0); }
822
823 let n = rows;
824 let mut x = vec![1.0; n];
825 let mut lambda = 0.0;
826
827 for _ in 0..10 {
829 let mut ax = vec![0.0; n];
830
831 for i in 0..n {
833 for j in 0..n {
834 ax[i] += matrix[i * n + j] * x[j];
835 }
836 }
837
838 let dot_ax_x: f64 = ax.iter().zip(x.iter()).map(|(a, b)| a * b).sum();
840 let dot_x_x: f64 = x.iter().map(|a| a * a).sum();
841
842 if dot_x_x > 1e-12 {
843 lambda = dot_ax_x / dot_x_x;
844 }
845
846 let norm = (ax.iter().map(|a| a * a).sum::<f64>()).sqrt();
848 if norm > 1e-12 {
849 x = ax.iter().map(|a| a / norm).collect();
850 }
851 }
852
853 Ok(lambda.abs())
854 }
855
856 fn compute_information_content(&self, values: &[f64]) -> Result<InformationContent> {
857 let mut histogram = std::collections::HashMap::new();
859 let quantization_levels = 100;
860 let min_val = values.iter().cloned().fold(f64::INFINITY, f64::min);
861 let max_val = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
862 let range = max_val - min_val;
863
864 if range > 1e-12 {
865 for &value in values {
866 let bucket = ((value - min_val) / range * quantization_levels as f64) as usize;
867 let bucket = bucket.min(quantization_levels - 1);
868 *histogram.entry(bucket).or_insert(0) += 1;
869 }
870 }
871
872 let total_count = values.len() as f64;
873 let entropy = if total_count > 0.0 {
874 histogram
875 .values()
876 .map(|&count| {
877 let p = count as f64 / total_count;
878 if p > 0.0 {
879 -p * p.log2()
880 } else {
881 0.0
882 }
883 })
884 .sum()
885 } else {
886 0.0
887 };
888
889 let effective_rank = if entropy > 0.0 { 2.0_f64.powf(entropy) } else { 1.0 };
891
892 let mut sorted_values = values.to_vec();
894 sorted_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
895 sorted_values.dedup_by(|a, b| (*a - *b).abs() < 1e-10);
896 let unique_values = sorted_values.len();
897 let compression_ratio = unique_values as f64 / values.len() as f64;
898
899 Ok(InformationContent {
900 entropy,
901 mutual_information: 0.0, effective_rank,
903 compression_ratio,
904 })
905 }
906
907 fn compute_stability_metrics(&self, values: &[f64]) -> Result<StabilityMetrics> {
908 let mean = values.iter().sum::<f64>() / values.len() as f64;
910 let variance = values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / values.len() as f64;
911 let std_dev = variance.sqrt();
912
913 let numerical_stability = if std_dev > 1e-12 {
914 1.0 / (1.0 + std_dev / mean.abs().max(1e-12))
915 } else {
916 1.0
917 };
918
919 let max_abs = values.iter().map(|x| x.abs()).fold(0.0, f64::max);
921 let perturbation_sensitivity = if max_abs > 1e-12 { std_dev / max_abs } else { 0.0 };
922
923 let robustness_score = numerical_stability * (1.0 - perturbation_sensitivity.min(1.0));
925
926 Ok(StabilityMetrics {
927 numerical_stability,
928 gradient_stability: 0.8, perturbation_sensitivity,
930 robustness_score,
931 })
932 }
933
934 fn compute_relationship_analysis(&self, values: &[f64]) -> Result<RelationshipAnalysis> {
935 let mut cross_correlations = HashMap::new();
937 let mut dependency_strength = HashMap::new();
938 let causal_relationships = Vec::new(); for tensor_info in self.tracked_tensors.values() {
942 if tensor_info.stats.total_elements > 0 {
943 let correlation =
944 self.compute_simple_correlation(values, &[tensor_info.stats.mean]);
945 cross_correlations.insert(tensor_info.id, correlation);
946 dependency_strength.insert(tensor_info.id, correlation.abs());
947 }
948 }
949
950 Ok(RelationshipAnalysis {
951 cross_correlations,
952 dependency_strength,
953 causal_relationships,
954 })
955 }
956
957 fn compute_simple_correlation(&self, values1: &[f64], values2: &[f64]) -> f64 {
958 if values1.is_empty() || values2.is_empty() {
959 return 0.0;
960 }
961
962 let mean1 = values1.iter().sum::<f64>() / values1.len() as f64;
963 let mean2 = values2.iter().sum::<f64>() / values2.len() as f64;
964
965 let min_len = values1.len().min(values2.len());
966 let numerator: f64 = values1
967 .iter()
968 .zip(values2.iter())
969 .take(min_len)
970 .map(|(x, y)| (x - mean1) * (y - mean2))
971 .sum();
972
973 let sum_sq1: f64 = values1.iter().take(min_len).map(|x| (x - mean1).powi(2)).sum();
974 let sum_sq2: f64 = values2.iter().take(min_len).map(|y| (y - mean2).powi(2)).sum();
975
976 let denominator = (sum_sq1 * sum_sq2).sqrt();
977 if denominator > 1e-12 {
978 numerator / denominator
979 } else {
980 0.0
981 }
982 }
983
984 fn compute_summary_stats(&self) -> HashMap<String, f64> {
985 let mut stats = HashMap::new();
986
987 if !self.tracked_tensors.is_empty() {
988 let values: Vec<f64> = self.tracked_tensors.values().map(|t| t.stats.mean).collect();
989 stats.insert(
990 "mean_of_means".to_string(),
991 values.iter().sum::<f64>() / values.len() as f64,
992 );
993
994 let total_memory: usize =
995 self.tracked_tensors.values().map(|t| t.stats.memory_usage_bytes).sum();
996 stats.insert(
997 "total_memory_mb".to_string(),
998 total_memory as f64 / (1024.0 * 1024.0),
999 );
1000
1001 let avg_sparsity: f64 =
1002 self.tracked_tensors.values().map(|t| t.stats.sparsity).sum::<f64>()
1003 / self.tracked_tensors.len() as f64;
1004 stats.insert("avg_sparsity".to_string(), avg_sparsity);
1005
1006 stats.insert(
1008 "total_dependencies".to_string(),
1009 self.dependencies.len() as f64,
1010 );
1011 stats.insert(
1012 "monitored_tensors".to_string(),
1013 self.time_series.len() as f64,
1014 );
1015 stats.insert(
1016 "active_lifecycles".to_string(),
1017 self.lifecycles.len() as f64,
1018 );
1019 }
1020
1021 stats
1022 }
1023}
1024
/// Category of a tensor alert.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TensorAlertType {
    /// Tensor contains NaN elements.
    NaNValues,
    /// Tensor contains +/- infinite elements.
    InfiniteValues,
    /// Tensor contains values of extreme magnitude; also reused for
    /// statistical-drift warnings.
    ExtremeValues,
    /// Gradient L2 norm is near zero.
    VanishingGradients,
    /// Gradient L2 norm is very large.
    ExplodingGradients,
    /// Excessive memory usage (not raised anywhere in this module).
    MemoryUsage,
    /// Shape mismatch (not raised anywhere in this module).
    ShapeMismatch,
}
1036
/// Severity level attached to a [`TensorAlert`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AlertSeverity {
    /// Informational only.
    Info,
    /// Potential problem worth investigating.
    Warning,
    /// Serious problem (e.g. NaN/inf values, exploding gradients).
    Critical,
}
1044
/// An alert raised for a specific tensor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorAlert {
    /// Unique alert id.
    pub id: Uuid,
    /// Id of the offending tensor.
    pub tensor_id: Uuid,
    /// Name of the offending tensor.
    pub tensor_name: String,
    /// What kind of problem was detected.
    pub alert_type: TensorAlertType,
    /// How serious the problem is.
    pub severity: AlertSeverity,
    /// Human-readable description of the problem.
    pub message: String,
    /// UTC time the alert was raised.
    pub timestamp: chrono::DateTime<chrono::Utc>,
}
1056
/// Aggregated snapshot of the inspector's state, built by
/// `TensorInspector::generate_report`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorInspectionReport {
    /// Number of tensors currently tracked.
    pub total_tensors: usize,
    /// Tracked tensors containing NaN or infinite values.
    pub tensors_with_issues: usize,
    /// Sum of tracked tensors' memory footprints, in bytes.
    pub total_memory_usage: usize,
    /// All alerts raised so far.
    pub alerts: Vec<TensorAlert>,
    /// All comparisons performed so far.
    pub comparisons: Vec<TensorComparison>,
    /// Miscellaneous aggregate metrics keyed by name.
    pub summary_stats: HashMap<String, f64>,
}
1067
1068impl TensorInspectionReport {
1069 pub fn has_nan_values(&self) -> bool {
1070 self.alerts.iter().any(|a| matches!(a.alert_type, TensorAlertType::NaNValues))
1071 }
1072
1073 pub fn has_inf_values(&self) -> bool {
1074 self.alerts
1075 .iter()
1076 .any(|a| matches!(a.alert_type, TensorAlertType::InfiniteValues))
1077 }
1078
1079 pub fn total_nan_count(&self) -> usize {
1080 self.alerts
1081 .iter()
1082 .filter(|a| matches!(a.alert_type, TensorAlertType::NaNValues))
1083 .count()
1084 }
1085
1086 pub fn total_inf_count(&self) -> usize {
1087 self.alerts
1088 .iter()
1089 .filter(|a| matches!(a.alert_type, TensorAlertType::InfiniteValues))
1090 .count()
1091 }
1092
1093 pub fn has_critical_alerts(&self) -> bool {
1094 self.alerts.iter().any(|a| matches!(a.severity, AlertSeverity::Critical))
1095 }
1096}