Skip to main content

trustformers_debug/
tensor_inspector.rs

1//! Tensor inspection and analysis tools
2
3use anyhow::Result;
4use scirs2_core::ndarray::*; // SciRS2 Integration Policy - was: use ndarray::{Array, ArrayD, Axis, IxDyn};
5use scirs2_core::random::random; // SciRS2 Integration Policy
6use serde::{Deserialize, Serialize};
7use std::collections::{HashMap, VecDeque};
8use std::fmt;
9use std::time::Instant;
10use uuid::Uuid;
11
12use crate::DebugConfig;
13
/// Summary statistics of a tensor's values (after conversion to `f64`).
///
/// Produced by `TensorInspector::compute_tensor_stats`, both for inspected tensors and for
/// their gradients.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorStats {
    /// Dimensions of the inspected array.
    pub shape: Vec<usize>,
    /// Element type label; `compute_tensor_stats` currently always writes `"f64"`.
    pub dtype: String,
    /// Number of elements analysed.
    pub total_elements: usize,
    /// Arithmetic mean of all elements.
    pub mean: f64,
    /// Population standard deviation (divides by `total_elements`, not `n - 1`).
    pub std: f64,
    /// Smallest non-NaN element.
    pub min: f64,
    /// Largest non-NaN element.
    pub max: f64,
    /// Median of the non-NaN elements (NaN when there are none).
    pub median: f64,
    /// Sum of absolute values.
    pub l1_norm: f64,
    /// Euclidean norm: square root of the sum of squares.
    pub l2_norm: f64,
    /// Largest absolute value.
    pub infinity_norm: f64,
    /// Number of NaN elements.
    pub nan_count: usize,
    /// Number of `+inf` / `-inf` elements.
    pub inf_count: usize,
    /// Number of elements exactly equal to zero.
    pub zero_count: usize,
    /// `total_elements * size_of::<T>()` for the original element type `T`.
    pub memory_usage_bytes: usize,
    /// Fraction of exactly-zero elements: `zero_count / total_elements`.
    pub sparsity: f64,
}
34
/// Distribution summary of a tensor's finite values (see `TensorInspector::compute_distribution`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorDistribution {
    /// `(bin start, count)` pairs, one entry per equal-width histogram bin.
    pub histogram: Vec<(f64, usize)>,
    /// Percentiles keyed `"p5"`, `"p25"`, `"p50"`, `"p75"`, `"p95"` and `"p99"`, taken by
    /// nearest lower rank from the sorted finite values.
    pub percentiles: HashMap<String, f64>,
    /// Values outside the `1.5 × IQR` fences around `p25`/`p75`, ascending, at most 100.
    pub outliers: Vec<f64>,
    /// Third standardized moment (population form); 0.0 when the std is 0.
    pub skewness: f64,
    /// Excess kurtosis: fourth standardized moment minus 3; 0.0 when the std is 0.
    pub kurtosis: f64,
}
44
/// Everything recorded about one inspected tensor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorInfo {
    /// Random ID generated by `inspect_tensor` and returned to the caller.
    pub id: Uuid,
    /// Caller-supplied display name.
    pub name: String,
    /// Optional layer label; used by `get_tensors_by_layer`.
    pub layer_name: Option<String>,
    /// Optional caller-supplied operation label.
    pub operation: Option<String>,
    /// UTC time at which the tensor was inspected.
    pub timestamp: chrono::DateTime<chrono::Utc>,
    /// Statistics of the tensor's values.
    pub stats: TensorStats,
    /// Distribution summary; `None` when sampling skipped it for this inspection.
    pub distribution: Option<TensorDistribution>,
    /// Statistics of the tensor's gradients, filled in by `inspect_gradients`.
    pub gradient_stats: Option<TensorStats>,
}
57
/// Result of comparing two tracked tensors via `TensorInspector::compare_tensors`.
///
/// NOTE: only summary statistics are kept for tracked tensors, so the numeric fields are
/// approximations derived from those statistics rather than element-wise metrics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorComparison {
    pub tensor1_id: Uuid,
    pub tensor2_id: Uuid,
    /// Squared difference of the two means.
    pub mse: f64,
    /// Absolute difference of the two means.
    pub mae: f64,
    /// Absolute difference of the two maxima.
    pub max_diff: f64,
    /// Heuristic: product of the means over the product of the L2 norms (0.0 if either
    /// norm is zero). Not a true cosine similarity.
    pub cosine_similarity: f64,
    /// Placeholder: 0.0 if either standard deviation is zero, otherwise a fixed 0.5.
    pub correlation: f64,
    /// Whether the two shapes are identical.
    pub shape_match: bool,
    /// Whether the recorded dtype labels are equal.
    pub dtype_match: bool,
}
71
/// Rolling history of a tensor's statistics, recorded while monitoring is enabled.
///
/// `timestamps[i]` is the UTC time at which `values[i]` was recorded.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorTimeSeries {
    pub tensor_id: Uuid,
    /// Recording times, oldest first.
    pub timestamps: VecDeque<chrono::DateTime<chrono::Utc>>,
    /// Recorded statistics, oldest first.
    pub values: VecDeque<TensorStats>,
    /// Maximum number of entries kept; the oldest are dropped once it is exceeded
    /// (`update_time_series` creates series with 1000).
    pub max_history: usize,
}
80
/// A directed dependency edge between two tensors, registered with `track_dependency`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorDependency {
    /// Tensor the edge starts from (the key used by `analyze_tensor_relationships`).
    pub source_id: Uuid,
    /// Tensor the edge points to.
    pub target_id: Uuid,
    /// Free-form label for the operation linking the two tensors.
    pub operation: String,
    /// Caller-supplied weight; not interpreted by the visible code.
    pub weight: f64,
}
89
/// A single event in a tensor's lifecycle log (see `record_lifecycle_event`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TensorLifecycleEvent {
    /// Recorded by `inspect_tensor`; `size_bytes` is element size times element count.
    Created { size_bytes: usize },
    /// Modification by the named operation.
    Modified { operation: String },
    /// An access of the given kind; each one increments `TensorLifecycle::total_accesses`.
    Accessed { access_type: String },
    /// End of the tensor's life. NOTE(review): not emitted by the visible code — presumably
    /// recorded by callers.
    Destroyed,
}
98
/// Event log for one tensor, maintained by `record_lifecycle_event`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorLifecycle {
    pub tensor_id: Uuid,
    /// Events in recording order, each paired with the UTC time it was recorded.
    pub events: Vec<(chrono::DateTime<chrono::Utc>, TensorLifecycleEvent)>,
    /// Number of `Accessed` events recorded.
    pub total_accesses: usize,
    /// Time at which this record was created, i.e. when its first event was recorded.
    pub creation_time: chrono::DateTime<chrono::Utc>,
}
107
/// Bundle of results returned by `TensorInspector::perform_advanced_analysis`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdvancedTensorAnalysis {
    /// Present only for 2-D tensors with at least 4 elements.
    pub spectral_analysis: Option<SpectralAnalysis>,
    /// Entropy-based summary of the values.
    pub information_content: InformationContent,
    /// Heuristic stability scores.
    pub stability_metrics: StabilityMetrics,
    /// Relationships to other tensors.
    pub relationship_analysis: RelationshipAnalysis,
}
116
/// Lightweight spectral estimates for a 2-D tensor (see `compute_spectral_analysis`).
///
/// These are cheap approximations, not results of an eigen- or singular-value decomposition.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpectralAnalysis {
    /// A single entry: the power-iteration estimate of the dominant eigenvalue's magnitude
    /// (0.0 for non-square matrices).
    pub eigenvalues: Vec<f64>,
    /// Dominant-eigenvalue estimate divided by the Frobenius norm (infinite when that norm
    /// is ~0). Not the classical `σ_max / σ_min` ratio.
    pub condition_number: f64,
    /// Number of entries with `|x| > 1e-10`, capped at `min(rows, cols)`; a heuristic, not
    /// the true matrix rank.
    pub rank: usize,
    /// Same value as the dominant-eigenvalue estimate.
    pub spectral_norm: f64,
}
125
/// Information-theoretic summary of a tensor's values (see `compute_information_content`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InformationContent {
    /// Shannon entropy, in bits, of the values quantized into 100 equal-width levels.
    pub entropy: f64,
    /// Always 0.0 in the visible code (computing it would need a second tensor).
    pub mutual_information: f64,
    /// `2^entropy`, or 1.0 when the entropy is 0.
    pub effective_rank: f64,
    /// Distinct values (within 1e-10) divided by the total number of values.
    pub compression_ratio: f64,
}
134
/// Heuristic stability scores produced by `compute_stability_metrics`.
///
/// NOTE(review): that computation is only partly visible here; `numerical_stability`
/// appears to be `1 / (1 + std / |mean|)` — confirm it and the remaining fields there.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StabilityMetrics {
    pub numerical_stability: f64,
    pub gradient_stability: f64,
    pub perturbation_sensitivity: f64,
    pub robustness_score: f64,
}
143
/// Relationships between a tensor and other tensors; the maps are keyed by tensor `Uuid`.
///
/// NOTE(review): produced by `compute_relationship_analysis`, which is not visible here —
/// confirm the meaning of each map's values against it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RelationshipAnalysis {
    pub cross_correlations: HashMap<Uuid, f64>,
    pub dependency_strength: HashMap<Uuid, f64>,
    pub causal_relationships: Vec<TensorDependency>,
}
151
/// Enhanced tensor inspector for detailed analysis.
///
/// Records per-tensor statistics, alerts, comparisons, dependency edges, lifecycle events
/// and (while monitoring is enabled) time series of statistics. Raw tensor values are not
/// retained; only the derived `TensorInfo` is stored.
#[derive(Debug)]
pub struct TensorInspector {
    /// Configuration; `max_tracked_tensors` caps `tracked_tensors` and `sampling_rate`
    /// controls how often value distributions are computed.
    config: DebugConfig,
    /// Stored inspection results, keyed by the ID returned from `inspect_tensor`.
    tracked_tensors: HashMap<Uuid, TensorInfo>,
    /// History of results produced by `compare_tensors`.
    comparisons: Vec<TensorComparison>,
    /// Alerts raised by tensor and gradient checks.
    alerts: Vec<TensorAlert>,
    // New advanced features
    /// Per-tensor statistics history; only appended to while monitoring is enabled.
    time_series: HashMap<Uuid, TensorTimeSeries>,
    /// Dependency edges registered via `track_dependency`.
    dependencies: Vec<TensorDependency>,
    /// Lifecycle event logs keyed by tensor ID.
    lifecycles: HashMap<Uuid, TensorLifecycle>,
    /// Gate for `update_time_series`.
    monitoring_enabled: bool,
    /// NOTE(review): only ever reset to `None` in the visible code (constructor and
    /// `clear`); presumably written elsewhere — confirm.
    last_analysis_time: Option<Instant>,
}
166
167impl TensorInspector {
168    /// Create a new tensor inspector
169    pub fn new(config: &DebugConfig) -> Self {
170        Self {
171            config: config.clone(),
172            tracked_tensors: HashMap::new(),
173            comparisons: Vec::new(),
174            alerts: Vec::new(),
175            // Initialize new advanced features
176            time_series: HashMap::new(),
177            dependencies: Vec::new(),
178            lifecycles: HashMap::new(),
179            monitoring_enabled: false,
180            last_analysis_time: None,
181        }
182    }
183
184    /// Start the tensor inspector
185    pub async fn start(&mut self) -> Result<()> {
186        tracing::info!("Starting tensor inspector");
187        Ok(())
188    }
189
190    /// Inspect a tensor and return detailed analysis
191    pub fn inspect_tensor<T>(
192        &mut self,
193        tensor: &ArrayD<T>,
194        name: &str,
195        layer_name: Option<&str>,
196        operation: Option<&str>,
197    ) -> Result<Uuid>
198    where
199        T: Clone + Into<f64> + fmt::Debug + 'static,
200    {
201        let id = Uuid::new_v4();
202
203        // Convert to f64 for analysis
204        let values: Vec<f64> = tensor.iter().map(|x| x.clone().into()).collect();
205        let shape = tensor.shape().to_vec();
206
207        let stats = self.compute_tensor_stats(&values, &shape, std::mem::size_of::<T>())?;
208        let distribution = if self.should_compute_distribution() {
209            Some(self.compute_distribution(&values)?)
210        } else {
211            None
212        };
213
214        let tensor_info = TensorInfo {
215            id,
216            name: name.to_string(),
217            layer_name: layer_name.map(|s| s.to_string()),
218            operation: operation.map(|s| s.to_string()),
219            timestamp: chrono::Utc::now(),
220            stats,
221            distribution,
222            gradient_stats: None,
223        };
224
225        // Check for alerts
226        self.check_tensor_alerts(&tensor_info)?;
227
228        // Store tensor info if we have space
229        if self.tracked_tensors.len() < self.config.max_tracked_tensors {
230            self.tracked_tensors.insert(id, tensor_info.clone());
231        }
232
233        // Advanced features integration
234        self.record_lifecycle_event(
235            id,
236            TensorLifecycleEvent::Created {
237                size_bytes: std::mem::size_of::<T>() * tensor.len(),
238            },
239        );
240
241        if self.monitoring_enabled {
242            if let Some(tensor_info) = self.tracked_tensors.get(&id) {
243                self.update_time_series(id, tensor_info.stats.clone());
244            }
245        }
246
247        Ok(id)
248    }
249
250    /// Inspect tensor gradients
251    pub fn inspect_gradients<T>(&mut self, tensor_id: Uuid, gradients: &ArrayD<T>) -> Result<()>
252    where
253        T: Clone + Into<f64> + fmt::Debug + 'static,
254    {
255        let values: Vec<f64> = gradients.iter().map(|x| x.clone().into()).collect();
256        let shape = gradients.shape().to_vec();
257
258        let gradient_stats =
259            self.compute_tensor_stats(&values, &shape, std::mem::size_of::<T>())?;
260
261        if let Some(tensor_info) = self.tracked_tensors.get_mut(&tensor_id) {
262            tensor_info.gradient_stats = Some(gradient_stats);
263        }
264
265        // Check for gradient-specific alerts - get the data we need first
266        let tensor_info_for_alerts = self
267            .tracked_tensors
268            .get(&tensor_id)
269            .map(|info| (info.id, info.name.clone(), info.gradient_stats.clone()));
270
271        if let Some((id, name, grad_stats)) = tensor_info_for_alerts {
272            self.check_gradient_alerts_with_data(id, &name, grad_stats)?;
273        }
274
275        Ok(())
276    }
277
278    /// Compare two tensors
279    pub fn compare_tensors(&mut self, id1: Uuid, id2: Uuid) -> Result<TensorComparison> {
280        let tensor1 = self
281            .tracked_tensors
282            .get(&id1)
283            .ok_or_else(|| anyhow::anyhow!("Tensor {} not found", id1))?;
284        let tensor2 = self
285            .tracked_tensors
286            .get(&id2)
287            .ok_or_else(|| anyhow::anyhow!("Tensor {} not found", id2))?;
288
289        let comparison = TensorComparison {
290            tensor1_id: id1,
291            tensor2_id: id2,
292            mse: self.compute_mse(&tensor1.stats, &tensor2.stats),
293            mae: self.compute_mae(&tensor1.stats, &tensor2.stats),
294            max_diff: (tensor1.stats.max - tensor2.stats.max).abs(),
295            cosine_similarity: self.compute_cosine_similarity(&tensor1.stats, &tensor2.stats),
296            correlation: self.compute_correlation(&tensor1.stats, &tensor2.stats),
297            shape_match: tensor1.stats.shape == tensor2.stats.shape,
298            dtype_match: tensor1.stats.dtype == tensor2.stats.dtype,
299        };
300
301        self.comparisons.push(comparison.clone());
302        Ok(comparison)
303    }
304
305    /// Get tensor information by ID
306    pub fn get_tensor_info(&self, id: Uuid) -> Option<&TensorInfo> {
307        self.tracked_tensors.get(&id)
308    }
309
310    /// Get all tracked tensors
311    pub fn get_all_tensors(&self) -> Vec<&TensorInfo> {
312        self.tracked_tensors.values().collect()
313    }
314
315    /// Get tensors by layer name
316    pub fn get_tensors_by_layer(&self, layer_name: &str) -> Vec<&TensorInfo> {
317        self.tracked_tensors
318            .values()
319            .filter(|info| info.layer_name.as_ref() == Some(&layer_name.to_string()))
320            .collect()
321    }
322
323    /// Get all alerts
324    pub fn get_alerts(&self) -> &[TensorAlert] {
325        &self.alerts
326    }
327
328    /// Clear tracking data
329    pub fn clear(&mut self) {
330        self.tracked_tensors.clear();
331        self.comparisons.clear();
332        self.alerts.clear();
333        // Clear new advanced features
334        self.time_series.clear();
335        self.dependencies.clear();
336        self.lifecycles.clear();
337        self.last_analysis_time = None;
338    }
339
340    /// Generate inspection report
341    pub async fn generate_report(&self) -> Result<TensorInspectionReport> {
342        let total_tensors = self.tracked_tensors.len();
343        let tensors_with_issues = self
344            .tracked_tensors
345            .values()
346            .filter(|info| info.stats.nan_count > 0 || info.stats.inf_count > 0)
347            .count();
348
349        let memory_usage =
350            self.tracked_tensors.values().map(|info| info.stats.memory_usage_bytes).sum();
351
352        Ok(TensorInspectionReport {
353            total_tensors,
354            tensors_with_issues,
355            total_memory_usage: memory_usage,
356            alerts: self.alerts.clone(),
357            comparisons: self.comparisons.clone(),
358            summary_stats: self.compute_summary_stats(),
359        })
360    }
361
362    // Advanced debugging features
363
364    /// Enable real-time tensor monitoring
365    pub fn enable_monitoring(&mut self, enable: bool) {
366        self.monitoring_enabled = enable;
367        if enable {
368            tracing::info!("Real-time tensor monitoring enabled");
369        } else {
370            tracing::info!("Real-time tensor monitoring disabled");
371        }
372    }
373
374    /// Track tensor dependency
375    pub fn track_dependency(
376        &mut self,
377        source_id: Uuid,
378        target_id: Uuid,
379        operation: &str,
380        weight: f64,
381    ) {
382        let dependency = TensorDependency {
383            source_id,
384            target_id,
385            operation: operation.to_string(),
386            weight,
387        };
388        self.dependencies.push(dependency);
389    }
390
391    /// Record tensor lifecycle event
392    pub fn record_lifecycle_event(&mut self, tensor_id: Uuid, event: TensorLifecycleEvent) {
393        let lifecycle = self.lifecycles.entry(tensor_id).or_insert_with(|| TensorLifecycle {
394            tensor_id,
395            events: Vec::new(),
396            total_accesses: 0,
397            creation_time: chrono::Utc::now(),
398        });
399
400        lifecycle.events.push((chrono::Utc::now(), event.clone()));
401
402        if matches!(event, TensorLifecycleEvent::Accessed { .. }) {
403            lifecycle.total_accesses += 1;
404        }
405    }
406
407    /// Update tensor time series data
408    pub fn update_time_series(&mut self, tensor_id: Uuid, stats: TensorStats) {
409        if !self.monitoring_enabled {
410            return;
411        }
412
413        let time_series = self.time_series.entry(tensor_id).or_insert_with(|| TensorTimeSeries {
414            tensor_id,
415            timestamps: VecDeque::new(),
416            values: VecDeque::new(),
417            max_history: 1000, // Default history size
418        });
419
420        time_series.timestamps.push_back(chrono::Utc::now());
421        time_series.values.push_back(stats);
422
423        // Maintain maximum history size
424        while time_series.timestamps.len() > time_series.max_history {
425            time_series.timestamps.pop_front();
426            time_series.values.pop_front();
427        }
428    }
429
430    /// Perform advanced tensor analysis
431    pub fn perform_advanced_analysis<T>(&self, tensor: &ArrayD<T>) -> Result<AdvancedTensorAnalysis>
432    where
433        T: Clone + Into<f64> + fmt::Debug + 'static,
434    {
435        let values: Vec<f64> = tensor.iter().map(|x| x.clone().into()).collect();
436
437        Ok(AdvancedTensorAnalysis {
438            spectral_analysis: self.compute_spectral_analysis(&values, tensor.shape())?,
439            information_content: self.compute_information_content(&values)?,
440            stability_metrics: self.compute_stability_metrics(&values)?,
441            relationship_analysis: self.compute_relationship_analysis(&values)?,
442        })
443    }
444
445    /// Detect tensor anomalies using advanced algorithms
446    pub fn detect_advanced_anomalies(&self, tensor_id: Uuid) -> Result<Vec<TensorAlert>> {
447        let mut alerts = Vec::new();
448
449        if let Some(time_series) = self.time_series.get(&tensor_id) {
450            // Detect drift in tensor statistics over time
451            if time_series.values.len() >= 10 {
452                let recent_mean =
453                    time_series.values.iter().rev().take(5).map(|stats| stats.mean).sum::<f64>()
454                        / 5.0;
455                let historical_mean =
456                    time_series.values.iter().take(5).map(|stats| stats.mean).sum::<f64>() / 5.0;
457
458                let drift_ratio =
459                    (recent_mean - historical_mean).abs() / historical_mean.abs().max(1e-8);
460
461                if drift_ratio > 0.5 {
462                    if let Some(tensor_info) = self.tracked_tensors.get(&tensor_id) {
463                        alerts.push(TensorAlert {
464                            id: Uuid::new_v4(),
465                            tensor_id,
466                            tensor_name: tensor_info.name.clone(),
467                            alert_type: TensorAlertType::ExtremeValues,
468                            severity: AlertSeverity::Warning,
469                            message: format!(
470                                "Detected statistical drift in tensor '{}': {:.2}% change",
471                                tensor_info.name,
472                                drift_ratio * 100.0
473                            ),
474                            timestamp: chrono::Utc::now(),
475                        });
476                    }
477                }
478            }
479        }
480
481        Ok(alerts)
482    }
483
484    /// Get tensor dependencies
485    pub fn get_dependencies(&self) -> &[TensorDependency] {
486        &self.dependencies
487    }
488
489    /// Get tensor lifecycle
490    pub fn get_lifecycle(&self, tensor_id: Uuid) -> Option<&TensorLifecycle> {
491        self.lifecycles.get(&tensor_id)
492    }
493
494    /// Get tensor time series
495    pub fn get_time_series(&self, tensor_id: Uuid) -> Option<&TensorTimeSeries> {
496        self.time_series.get(&tensor_id)
497    }
498
499    /// Analyze tensor relationships
500    pub fn analyze_tensor_relationships(&self) -> HashMap<Uuid, Vec<Uuid>> {
501        let mut relationships = HashMap::new();
502
503        for dependency in &self.dependencies {
504            relationships
505                .entry(dependency.source_id)
506                .or_insert_with(Vec::new)
507                .push(dependency.target_id);
508        }
509
510        relationships
511    }
512
513    /// Get frequently accessed tensors
514    pub fn get_frequent_tensors(&self, min_accesses: usize) -> Vec<Uuid> {
515        self.lifecycles
516            .iter()
517            .filter(|(_, lifecycle)| lifecycle.total_accesses >= min_accesses)
518            .map(|(id, _)| *id)
519            .collect()
520    }
521
522    // Private helper methods
523
524    fn compute_tensor_stats(
525        &self,
526        values: &[f64],
527        shape: &[usize],
528        element_size: usize,
529    ) -> Result<TensorStats> {
530        let total_elements = values.len();
531        let mean = values.iter().sum::<f64>() / total_elements as f64;
532
533        let variance =
534            values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / total_elements as f64;
535        let std = variance.sqrt();
536
537        let min = values.iter().cloned().fold(f64::INFINITY, f64::min);
538        let max = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
539
540        let mut sorted_values = values.to_vec();
541        // Filter out NaN values before sorting to avoid panic
542        sorted_values.retain(|x| !x.is_nan());
543        sorted_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
544        let median = if sorted_values.is_empty() {
545            f64::NAN
546        } else if sorted_values.len() % 2 == 0 {
547            (sorted_values[sorted_values.len() / 2 - 1] + sorted_values[sorted_values.len() / 2])
548                / 2.0
549        } else {
550            sorted_values[sorted_values.len() / 2]
551        };
552
553        let l1_norm = values.iter().map(|x| x.abs()).sum::<f64>();
554        let l2_norm = values.iter().map(|x| x * x).sum::<f64>().sqrt();
555        let infinity_norm = values.iter().map(|x| x.abs()).fold(0.0, f64::max);
556
557        let nan_count = values.iter().filter(|x| x.is_nan()).count();
558        let inf_count = values.iter().filter(|x| x.is_infinite()).count();
559        let zero_count = values.iter().filter(|x| **x == 0.0).count();
560
561        let memory_usage_bytes = total_elements * element_size;
562        let sparsity = zero_count as f64 / total_elements as f64;
563
564        Ok(TensorStats {
565            shape: shape.to_vec(),
566            dtype: "f64".to_string(), // Simplified for now
567            total_elements,
568            mean,
569            std,
570            min,
571            max,
572            median,
573            l1_norm,
574            l2_norm,
575            infinity_norm,
576            nan_count,
577            inf_count,
578            zero_count,
579            memory_usage_bytes,
580            sparsity,
581        })
582    }
583
584    fn compute_distribution(&self, values: &[f64]) -> Result<TensorDistribution> {
585        // Simple histogram computation
586        let num_bins = 50;
587        let min = values.iter().cloned().fold(f64::INFINITY, f64::min);
588        let max = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
589        let bin_width = (max - min) / num_bins as f64;
590
591        let mut histogram = vec![(0.0, 0); num_bins];
592        for &value in values {
593            if !value.is_finite() {
594                continue;
595            }
596            let bin_idx = ((value - min) / bin_width).floor() as usize;
597            let bin_idx = bin_idx.min(num_bins - 1);
598            histogram[bin_idx].0 = min + bin_idx as f64 * bin_width;
599            histogram[bin_idx].1 += 1;
600        }
601
602        // Compute percentiles
603        let mut sorted_values =
604            values.iter().cloned().filter(|x| x.is_finite()).collect::<Vec<_>>();
605        sorted_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
606
607        let mut percentiles = HashMap::new();
608        for &p in &[5.0, 25.0, 50.0, 75.0, 95.0, 99.0] {
609            let idx = ((p / 100.0) * (sorted_values.len() - 1) as f64) as usize;
610            percentiles.insert(format!("p{}", p as u8), sorted_values[idx]);
611        }
612
613        // Identify outliers (simple method using IQR)
614        let q1 = percentiles["p25"];
615        let q3 = percentiles["p75"];
616        let iqr = q3 - q1;
617        let lower_bound = q1 - 1.5 * iqr;
618        let upper_bound = q3 + 1.5 * iqr;
619
620        let outliers: Vec<f64> = sorted_values
621            .iter()
622            .cloned()
623            .filter(|&x| x < lower_bound || x > upper_bound)
624            .take(100) // Limit outliers to avoid memory issues
625            .collect();
626
627        // Basic skewness and kurtosis (simplified formulas)
628        let mean = sorted_values.iter().sum::<f64>() / sorted_values.len() as f64;
629        let variance = sorted_values.iter().map(|x| (x - mean).powi(2)).sum::<f64>()
630            / sorted_values.len() as f64;
631        let std = variance.sqrt();
632
633        let skewness = if std > 0.0 {
634            sorted_values.iter().map(|x| ((x - mean) / std).powi(3)).sum::<f64>()
635                / sorted_values.len() as f64
636        } else {
637            0.0
638        };
639
640        let kurtosis = if std > 0.0 {
641            sorted_values.iter().map(|x| ((x - mean) / std).powi(4)).sum::<f64>()
642                / sorted_values.len() as f64
643                - 3.0
644        } else {
645            0.0
646        };
647
648        Ok(TensorDistribution {
649            histogram,
650            percentiles,
651            outliers,
652            skewness,
653            kurtosis,
654        })
655    }
656
657    fn should_compute_distribution(&self) -> bool {
658        self.config.sampling_rate >= 1.0
659            || (self.config.sampling_rate > 0.0 && random::<f32>() < self.config.sampling_rate)
660    }
661
662    fn check_tensor_alerts(&mut self, tensor_info: &TensorInfo) -> Result<()> {
663        // Check for NaN values
664        if tensor_info.stats.nan_count > 0 {
665            self.alerts.push(TensorAlert {
666                id: Uuid::new_v4(),
667                tensor_id: tensor_info.id,
668                tensor_name: tensor_info.name.clone(),
669                alert_type: TensorAlertType::NaNValues,
670                severity: AlertSeverity::Critical,
671                message: format!(
672                    "Found {} NaN values in tensor '{}'",
673                    tensor_info.stats.nan_count, tensor_info.name
674                ),
675                timestamp: chrono::Utc::now(),
676            });
677        }
678
679        // Check for infinite values
680        if tensor_info.stats.inf_count > 0 {
681            self.alerts.push(TensorAlert {
682                id: Uuid::new_v4(),
683                tensor_id: tensor_info.id,
684                tensor_name: tensor_info.name.clone(),
685                alert_type: TensorAlertType::InfiniteValues,
686                severity: AlertSeverity::Critical,
687                message: format!(
688                    "Found {} infinite values in tensor '{}'",
689                    tensor_info.stats.inf_count, tensor_info.name
690                ),
691                timestamp: chrono::Utc::now(),
692            });
693        }
694
695        // Check for extreme values
696        if tensor_info.stats.max.abs() > 1e10 || tensor_info.stats.min.abs() > 1e10 {
697            self.alerts.push(TensorAlert {
698                id: Uuid::new_v4(),
699                tensor_id: tensor_info.id,
700                tensor_name: tensor_info.name.clone(),
701                alert_type: TensorAlertType::ExtremeValues,
702                severity: AlertSeverity::Warning,
703                message: format!(
704                    "Extreme values detected in tensor '{}': min={:.2e}, max={:.2e}",
705                    tensor_info.name, tensor_info.stats.min, tensor_info.stats.max
706                ),
707                timestamp: chrono::Utc::now(),
708            });
709        }
710
711        Ok(())
712    }
713
714    fn check_gradient_alerts_with_data(
715        &mut self,
716        tensor_id: Uuid,
717        tensor_name: &str,
718        grad_stats: Option<TensorStats>,
719    ) -> Result<()> {
720        if let Some(ref stats) = grad_stats {
721            // Check for vanishing gradients
722            if stats.l2_norm < 1e-8 {
723                self.alerts.push(TensorAlert {
724                    id: Uuid::new_v4(),
725                    tensor_id,
726                    tensor_name: tensor_name.to_string(),
727                    alert_type: TensorAlertType::VanishingGradients,
728                    severity: AlertSeverity::Warning,
729                    message: format!(
730                        "Vanishing gradients detected in '{}': L2 norm = {:.2e}",
731                        tensor_name, stats.l2_norm
732                    ),
733                    timestamp: chrono::Utc::now(),
734                });
735            }
736
737            // Check for exploding gradients
738            if stats.l2_norm > 100.0 {
739                self.alerts.push(TensorAlert {
740                    id: Uuid::new_v4(),
741                    tensor_id,
742                    tensor_name: tensor_name.to_string(),
743                    alert_type: TensorAlertType::ExplodingGradients,
744                    severity: AlertSeverity::Critical,
745                    message: format!(
746                        "Exploding gradients detected in '{}': L2 norm = {:.2e}",
747                        tensor_name, stats.l2_norm
748                    ),
749                    timestamp: chrono::Utc::now(),
750                });
751            }
752        }
753
754        Ok(())
755    }
756
757    fn compute_mse(&self, stats1: &TensorStats, stats2: &TensorStats) -> f64 {
758        // Simplified MSE using means (would need actual tensor data for real MSE)
759        (stats1.mean - stats2.mean).powi(2)
760    }
761
762    fn compute_mae(&self, stats1: &TensorStats, stats2: &TensorStats) -> f64 {
763        // Simplified MAE using means
764        (stats1.mean - stats2.mean).abs()
765    }
766
767    fn compute_cosine_similarity(&self, stats1: &TensorStats, stats2: &TensorStats) -> f64 {
768        // Simplified using L2 norms (would need actual tensors for real cosine similarity)
769        if stats1.l2_norm == 0.0 || stats2.l2_norm == 0.0 {
770            0.0
771        } else {
772            (stats1.mean * stats2.mean) / (stats1.l2_norm * stats2.l2_norm)
773        }
774    }
775
776    fn compute_correlation(&self, stats1: &TensorStats, stats2: &TensorStats) -> f64 {
777        // Simplified correlation (would need actual tensor data for real correlation)
778        if stats1.std == 0.0 || stats2.std == 0.0 {
779            0.0
780        } else {
781            0.5 // Placeholder
782        }
783    }
784
785    // Advanced analysis helper methods
786
787    fn compute_spectral_analysis(
788        &self,
789        values: &[f64],
790        shape: &[usize],
791    ) -> Result<Option<SpectralAnalysis>> {
792        // Only perform spectral analysis for 2D matrices
793        if shape.len() != 2 || values.len() < 4 {
794            return Ok(None);
795        }
796
797        let rows = shape[0];
798        let cols = shape[1];
799
800        // Simple eigenvalue estimation using power iteration for the largest eigenvalue
801        let largest_eigenvalue = self.power_iteration(values, rows, cols)?;
802
803        // Estimate condition number using Frobenius norm ratio
804        let frobenius_norm = (values.iter().map(|x| x * x).sum::<f64>()).sqrt();
805        let condition_number = if frobenius_norm > 1e-12 {
806            largest_eigenvalue / frobenius_norm.max(1e-12)
807        } else {
808            f64::INFINITY
809        };
810
811        // Estimate rank by counting non-zero singular values (simplified)
812        let rank = values.iter().filter(|&&x| x.abs() > 1e-10).count().min(rows.min(cols));
813
814        Ok(Some(SpectralAnalysis {
815            eigenvalues: vec![largest_eigenvalue], // Simplified - only largest
816            condition_number,
817            rank,
818            spectral_norm: largest_eigenvalue,
819        }))
820    }
821
822    fn power_iteration(&self, matrix: &[f64], rows: usize, cols: usize) -> Result<f64> {
823        if rows != cols {
824            return Ok(0.0); // Non-square matrices
825        }
826
827        let n = rows;
828        let mut x = vec![1.0; n];
829        let mut lambda = 0.0;
830
831        // Simple power iteration (few iterations for performance)
832        for _ in 0..10 {
833            let mut ax = vec![0.0; n];
834
835            // Matrix-vector multiplication: Ax
836            for i in 0..n {
837                for j in 0..n {
838                    ax[i] += matrix[i * n + j] * x[j];
839                }
840            }
841
842            // Compute Rayleigh quotient
843            let dot_ax_x: f64 = ax.iter().zip(x.iter()).map(|(a, b)| a * b).sum();
844            let dot_x_x: f64 = x.iter().map(|a| a * a).sum();
845
846            if dot_x_x > 1e-12 {
847                lambda = dot_ax_x / dot_x_x;
848            }
849
850            // Normalize
851            let norm = (ax.iter().map(|a| a * a).sum::<f64>()).sqrt();
852            if norm > 1e-12 {
853                x = ax.iter().map(|a| a / norm).collect();
854            }
855        }
856
857        Ok(lambda.abs())
858    }
859
860    fn compute_information_content(&self, values: &[f64]) -> Result<InformationContent> {
861        // Shannon entropy calculation
862        let mut histogram = std::collections::HashMap::new();
863        let quantization_levels = 100;
864        let min_val = values.iter().cloned().fold(f64::INFINITY, f64::min);
865        let max_val = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
866        let range = max_val - min_val;
867
868        if range > 1e-12 {
869            for &value in values {
870                let bucket = ((value - min_val) / range * quantization_levels as f64) as usize;
871                let bucket = bucket.min(quantization_levels - 1);
872                *histogram.entry(bucket).or_insert(0) += 1;
873            }
874        }
875
876        let total_count = values.len() as f64;
877        let entropy = if total_count > 0.0 {
878            histogram
879                .values()
880                .map(|&count| {
881                    let p = count as f64 / total_count;
882                    if p > 0.0 {
883                        -p * p.log2()
884                    } else {
885                        0.0
886                    }
887                })
888                .sum()
889        } else {
890            0.0
891        };
892
893        // Effective rank estimation using entropy
894        let effective_rank = if entropy > 0.0 { 2.0_f64.powf(entropy) } else { 1.0 };
895
896        // Simple compression ratio estimation
897        let mut sorted_values = values.to_vec();
898        sorted_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
899        sorted_values.dedup_by(|a, b| (*a - *b).abs() < 1e-10);
900        let unique_values = sorted_values.len();
901        let compression_ratio = unique_values as f64 / values.len() as f64;
902
903        Ok(InformationContent {
904            entropy,
905            mutual_information: 0.0, // Would need multiple tensors to compute
906            effective_rank,
907            compression_ratio,
908        })
909    }
910
911    fn compute_stability_metrics(&self, values: &[f64]) -> Result<StabilityMetrics> {
912        // Numerical stability based on condition of values
913        let mean = values.iter().sum::<f64>() / values.len() as f64;
914        let variance = values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / values.len() as f64;
915        let std_dev = variance.sqrt();
916
917        let numerical_stability = if std_dev > 1e-12 {
918            1.0 / (1.0 + std_dev / mean.abs().max(1e-12))
919        } else {
920            1.0
921        };
922
923        // Simple perturbation sensitivity
924        let max_abs = values.iter().map(|x| x.abs()).fold(0.0, f64::max);
925        let perturbation_sensitivity = if max_abs > 1e-12 { std_dev / max_abs } else { 0.0 };
926
927        // Overall robustness score
928        let robustness_score = numerical_stability * (1.0 - perturbation_sensitivity.min(1.0));
929
930        Ok(StabilityMetrics {
931            numerical_stability,
932            gradient_stability: 0.8, // Placeholder - would need gradient info
933            perturbation_sensitivity,
934            robustness_score,
935        })
936    }
937
938    fn compute_relationship_analysis(&self, values: &[f64]) -> Result<RelationshipAnalysis> {
939        // Simple relationship analysis within the tensor
940        let mut cross_correlations = HashMap::new();
941        let mut dependency_strength = HashMap::new();
942        let causal_relationships = Vec::new(); // Would need temporal data
943
944        // For now, just compute some basic relationships
945        for tensor_info in self.tracked_tensors.values() {
946            if tensor_info.stats.total_elements > 0 {
947                let correlation =
948                    self.compute_simple_correlation(values, &[tensor_info.stats.mean]);
949                cross_correlations.insert(tensor_info.id, correlation);
950                dependency_strength.insert(tensor_info.id, correlation.abs());
951            }
952        }
953
954        Ok(RelationshipAnalysis {
955            cross_correlations,
956            dependency_strength,
957            causal_relationships,
958        })
959    }
960
961    fn compute_simple_correlation(&self, values1: &[f64], values2: &[f64]) -> f64 {
962        if values1.is_empty() || values2.is_empty() {
963            return 0.0;
964        }
965
966        let mean1 = values1.iter().sum::<f64>() / values1.len() as f64;
967        let mean2 = values2.iter().sum::<f64>() / values2.len() as f64;
968
969        let min_len = values1.len().min(values2.len());
970        let numerator: f64 = values1
971            .iter()
972            .zip(values2.iter())
973            .take(min_len)
974            .map(|(x, y)| (x - mean1) * (y - mean2))
975            .sum();
976
977        let sum_sq1: f64 = values1.iter().take(min_len).map(|x| (x - mean1).powi(2)).sum();
978        let sum_sq2: f64 = values2.iter().take(min_len).map(|y| (y - mean2).powi(2)).sum();
979
980        let denominator = (sum_sq1 * sum_sq2).sqrt();
981        if denominator > 1e-12 {
982            numerator / denominator
983        } else {
984            0.0
985        }
986    }
987
988    fn compute_summary_stats(&self) -> HashMap<String, f64> {
989        let mut stats = HashMap::new();
990
991        if !self.tracked_tensors.is_empty() {
992            let values: Vec<f64> = self.tracked_tensors.values().map(|t| t.stats.mean).collect();
993            stats.insert(
994                "mean_of_means".to_string(),
995                values.iter().sum::<f64>() / values.len() as f64,
996            );
997
998            let total_memory: usize =
999                self.tracked_tensors.values().map(|t| t.stats.memory_usage_bytes).sum();
1000            stats.insert(
1001                "total_memory_mb".to_string(),
1002                total_memory as f64 / (1024.0 * 1024.0),
1003            );
1004
1005            let avg_sparsity: f64 =
1006                self.tracked_tensors.values().map(|t| t.stats.sparsity).sum::<f64>()
1007                    / self.tracked_tensors.len() as f64;
1008            stats.insert("avg_sparsity".to_string(), avg_sparsity);
1009
1010            // Add advanced statistics
1011            stats.insert(
1012                "total_dependencies".to_string(),
1013                self.dependencies.len() as f64,
1014            );
1015            stats.insert(
1016                "monitored_tensors".to_string(),
1017                self.time_series.len() as f64,
1018            );
1019            stats.insert(
1020                "active_lifecycles".to_string(),
1021                self.lifecycles.len() as f64,
1022            );
1023        }
1024
1025        stats
1026    }
1027}
1028
1029/// Alert types for tensor issues
1030#[derive(Debug, Clone, Serialize, Deserialize)]
1031pub enum TensorAlertType {
1032    NaNValues,
1033    InfiniteValues,
1034    ExtremeValues,
1035    VanishingGradients,
1036    ExplodingGradients,
1037    MemoryUsage,
1038    ShapeMismatch,
1039}
1040
1041/// Alert severity levels
1042#[derive(Debug, Clone, Serialize, Deserialize)]
1043pub enum AlertSeverity {
1044    Info,
1045    Warning,
1046    Critical,
1047}
1048
1049/// Tensor alert information
1050#[derive(Debug, Clone, Serialize, Deserialize)]
1051pub struct TensorAlert {
1052    pub id: Uuid,
1053    pub tensor_id: Uuid,
1054    pub tensor_name: String,
1055    pub alert_type: TensorAlertType,
1056    pub severity: AlertSeverity,
1057    pub message: String,
1058    pub timestamp: chrono::DateTime<chrono::Utc>,
1059}
1060
1061/// Tensor inspection report
1062#[derive(Debug, Clone, Serialize, Deserialize)]
1063pub struct TensorInspectionReport {
1064    pub total_tensors: usize,
1065    pub tensors_with_issues: usize,
1066    pub total_memory_usage: usize,
1067    pub alerts: Vec<TensorAlert>,
1068    pub comparisons: Vec<TensorComparison>,
1069    pub summary_stats: HashMap<String, f64>,
1070}
1071
1072impl TensorInspectionReport {
1073    pub fn has_nan_values(&self) -> bool {
1074        self.alerts.iter().any(|a| matches!(a.alert_type, TensorAlertType::NaNValues))
1075    }
1076
1077    pub fn has_inf_values(&self) -> bool {
1078        self.alerts
1079            .iter()
1080            .any(|a| matches!(a.alert_type, TensorAlertType::InfiniteValues))
1081    }
1082
1083    pub fn total_nan_count(&self) -> usize {
1084        self.alerts
1085            .iter()
1086            .filter(|a| matches!(a.alert_type, TensorAlertType::NaNValues))
1087            .count()
1088    }
1089
1090    pub fn total_inf_count(&self) -> usize {
1091        self.alerts
1092            .iter()
1093            .filter(|a| matches!(a.alert_type, TensorAlertType::InfiniteValues))
1094            .count()
1095    }
1096
1097    pub fn has_critical_alerts(&self) -> bool {
1098        self.alerts.iter().any(|a| matches!(a.severity, AlertSeverity::Critical))
1099    }
1100}