1use crate::errors::Result;
6use crate::tensor::{DType, Tensor};
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, VecDeque};
9use std::fmt;
10use std::sync::{Arc, Mutex};
11use std::time::Instant;
12
13#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
15pub enum Severity {
16 Info,
18 Warning,
20 Error,
22 Critical,
24}
25
26#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
28pub enum TensorIssueType {
29 NaN,
31 Infinity,
33 VanishingGradient,
35 ExplodingGradient,
37 AllZeros,
39 UnusualDistribution,
41 MemoryLeak,
43 DTypeMismatch,
45 ShapeMismatch,
47 OperationFailure,
49}
50
51impl fmt::Display for TensorIssueType {
52 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
53 match self {
54 TensorIssueType::NaN => write!(f, "NaN Values"),
55 TensorIssueType::Infinity => write!(f, "Infinite Values"),
56 TensorIssueType::VanishingGradient => write!(f, "Vanishing Gradient"),
57 TensorIssueType::ExplodingGradient => write!(f, "Exploding Gradient"),
58 TensorIssueType::AllZeros => write!(f, "All Zeros"),
59 TensorIssueType::UnusualDistribution => write!(f, "Unusual Distribution"),
60 TensorIssueType::MemoryLeak => write!(f, "Memory Leak"),
61 TensorIssueType::DTypeMismatch => write!(f, "DType Mismatch"),
62 TensorIssueType::ShapeMismatch => write!(f, "Shape Mismatch"),
63 TensorIssueType::OperationFailure => write!(f, "Operation Failure"),
64 }
65 }
66}
67
68#[derive(Debug, Clone, Serialize, Deserialize)]
70pub struct TensorDebugIssue {
71 pub issue_type: TensorIssueType,
73 pub severity: Severity,
75 pub message: String,
77 pub tensor_name: Option<String>,
79 pub operation: Option<String>,
81 pub location: Option<String>,
83 pub timestamp: std::time::SystemTime,
85 pub metadata: HashMap<String, String>,
87}
88
89impl TensorDebugIssue {
90 fn new(issue_type: TensorIssueType, severity: Severity, message: String) -> Self {
91 Self {
92 issue_type,
93 severity,
94 message,
95 tensor_name: None,
96 operation: None,
97 location: None,
98 timestamp: std::time::SystemTime::now(),
99 metadata: HashMap::new(),
100 }
101 }
102
103 pub fn with_tensor_name(mut self, name: String) -> Self {
105 self.tensor_name = Some(name);
106 self
107 }
108
109 pub fn with_operation(mut self, op: String) -> Self {
111 self.operation = Some(op);
112 self
113 }
114
115 pub fn with_location(mut self, location: String) -> Self {
117 self.location = Some(location);
118 self
119 }
120
121 pub fn with_metadata(mut self, key: String, value: String) -> Self {
123 self.metadata.insert(key, value);
124 self
125 }
126}
127
128#[derive(Debug, Clone, Serialize, Deserialize)]
130pub struct DebugTensorStats {
131 pub shape: Vec<usize>,
133 pub dtype: DType,
135 pub min: Option<f64>,
137 pub max: Option<f64>,
139 pub mean: Option<f64>,
141 pub std_dev: Option<f64>,
143 pub nan_count: usize,
145 pub inf_count: usize,
147 pub zero_count: usize,
149 pub total_elements: usize,
151 pub memory_bytes: usize,
153}
154
155impl DebugTensorStats {
156 pub fn from_tensor(tensor: &Tensor) -> Result<Self> {
158 let shape = tensor.shape().to_vec();
159 let dtype = tensor.dtype();
160 let total_elements = shape.iter().product();
161
162 let (min, max, mean, std_dev, nan_count, inf_count, zero_count) = match tensor {
164 Tensor::F32(arr) => {
165 let data: Vec<f32> = arr.iter().copied().collect();
166 let mut min_val = f64::INFINITY;
167 let mut max_val = f64::NEG_INFINITY;
168 let mut sum = 0.0;
169 let mut nan_count = 0;
170 let mut inf_count = 0;
171 let mut zero_count = 0;
172
173 for &val in &data {
174 if val.is_nan() {
175 nan_count += 1;
176 continue;
177 }
178 if val.is_infinite() {
179 inf_count += 1;
180 continue;
181 }
182 if val == 0.0 {
183 zero_count += 1;
184 }
185
186 let val_f64 = val as f64;
187 min_val = min_val.min(val_f64);
188 max_val = max_val.max(val_f64);
189 sum += val_f64;
190 }
191
192 let count = (data.len() - nan_count - inf_count) as f64;
193 let mean = if count > 0.0 { sum / count } else { 0.0 };
194
195 let mut sum_sq_diff = 0.0;
197 for &val in &data {
198 if !val.is_nan() && !val.is_infinite() {
199 let diff = val as f64 - mean;
200 sum_sq_diff += diff * diff;
201 }
202 }
203 let std_dev = if count > 0.0 { (sum_sq_diff / count).sqrt() } else { 0.0 };
204
205 (
206 Some(min_val),
207 Some(max_val),
208 Some(mean),
209 Some(std_dev),
210 nan_count,
211 inf_count,
212 zero_count,
213 )
214 },
215 Tensor::F64(arr) => {
216 let data: Vec<f64> = arr.iter().copied().collect();
217 let mut min_val = f64::INFINITY;
218 let mut max_val = f64::NEG_INFINITY;
219 let mut sum = 0.0;
220 let mut nan_count = 0;
221 let mut inf_count = 0;
222 let mut zero_count = 0;
223
224 for &val in &data {
225 if val.is_nan() {
226 nan_count += 1;
227 continue;
228 }
229 if val.is_infinite() {
230 inf_count += 1;
231 continue;
232 }
233 if val == 0.0 {
234 zero_count += 1;
235 }
236
237 min_val = min_val.min(val);
238 max_val = max_val.max(val);
239 sum += val;
240 }
241
242 let count = (data.len() - nan_count - inf_count) as f64;
243 let mean = if count > 0.0 { sum / count } else { 0.0 };
244
245 let mut sum_sq_diff = 0.0;
247 for &val in &data {
248 if !val.is_nan() && !val.is_infinite() {
249 let diff = val - mean;
250 sum_sq_diff += diff * diff;
251 }
252 }
253 let std_dev = if count > 0.0 { (sum_sq_diff / count).sqrt() } else { 0.0 };
254
255 (
256 Some(min_val),
257 Some(max_val),
258 Some(mean),
259 Some(std_dev),
260 nan_count,
261 inf_count,
262 zero_count,
263 )
264 },
265 _ => (None, None, None, None, 0, 0, 0),
266 };
267
268 let memory_bytes = total_elements * dtype.size_in_bytes();
269
270 Ok(Self {
271 shape,
272 dtype,
273 min,
274 max,
275 mean,
276 std_dev,
277 nan_count,
278 inf_count,
279 zero_count,
280 total_elements,
281 memory_bytes,
282 })
283 }
284
285 pub fn detect_issues(&self) -> Vec<TensorDebugIssue> {
287 let mut issues = Vec::new();
288
289 if self.nan_count > 0 {
291 issues.push(
292 TensorDebugIssue::new(
293 TensorIssueType::NaN,
294 Severity::Error,
295 format!(
296 "Found {} NaN values out of {}",
297 self.nan_count, self.total_elements
298 ),
299 )
300 .with_metadata("nan_count".to_string(), self.nan_count.to_string())
301 .with_metadata(
302 "nan_percentage".to_string(),
303 format!(
304 "{:.2}%",
305 100.0 * self.nan_count as f64 / self.total_elements as f64
306 ),
307 ),
308 );
309 }
310
311 if self.inf_count > 0 {
313 issues.push(
314 TensorDebugIssue::new(
315 TensorIssueType::Infinity,
316 Severity::Error,
317 format!(
318 "Found {} infinite values out of {}",
319 self.inf_count, self.total_elements
320 ),
321 )
322 .with_metadata("inf_count".to_string(), self.inf_count.to_string()),
323 );
324 }
325
326 if self.zero_count == self.total_elements {
328 issues.push(TensorDebugIssue::new(
329 TensorIssueType::AllZeros,
330 Severity::Warning,
331 "Tensor contains all zeros".to_string(),
332 ));
333 }
334
335 if let (Some(max_val), Some(mean_val)) = (self.max, self.mean) {
337 if max_val.abs() < 1e-7 && mean_val.abs() < 1e-7 {
338 issues.push(
339 TensorDebugIssue::new(
340 TensorIssueType::VanishingGradient,
341 Severity::Warning,
342 format!(
343 "Very small values detected (max: {:.2e}, mean: {:.2e})",
344 max_val, mean_val
345 ),
346 )
347 .with_metadata("max_value".to_string(), format!("{:.2e}", max_val))
348 .with_metadata("mean_value".to_string(), format!("{:.2e}", mean_val)),
349 );
350 }
351 }
352
353 if let Some(max_val) = self.max {
355 if max_val.abs() > 1e6 {
356 issues.push(
357 TensorDebugIssue::new(
358 TensorIssueType::ExplodingGradient,
359 Severity::Error,
360 format!("Very large values detected (max: {:.2e})", max_val),
361 )
362 .with_metadata("max_value".to_string(), format!("{:.2e}", max_val)),
363 );
364 }
365 }
366
367 issues
368 }
369}
370
/// Record describing one traced tensor operation.
#[derive(Debug, Clone)]
pub struct OperationTrace {
    /// Name of the operation.
    pub operation: String,
    /// Names of the input tensors.
    pub inputs: Vec<String>,
    /// Name of the output tensor.
    pub output: String,
    /// Monotonic timestamp associated with the operation.
    pub timestamp: Instant,
    /// Time the operation took.
    pub duration: std::time::Duration,
    /// Shape of each input; presumably parallel to `inputs` (confirm with the producer).
    pub input_shapes: Vec<Vec<usize>>,
    /// Shape of the output tensor.
    pub output_shape: Vec<usize>,
}
389
390#[derive(Debug, Clone, Serialize, Deserialize)]
392pub enum WatchCondition {
393 HasNaN,
395 HasInf,
397 ValueExceeds(f64),
399 ValueBelow(f64),
401 ShapeEquals(Vec<usize>),
403 Custom(String),
405}
406
407#[derive(Debug, Clone, Serialize, Deserialize)]
409pub struct Watchpoint {
410 pub tensor_pattern: String,
412 pub condition: WatchCondition,
414 pub break_on_trigger: bool,
416 pub trigger_count: usize,
418}
419
420#[derive(Debug, Clone, Serialize, Deserialize)]
422pub struct TensorDebuggerConfig {
423 pub auto_detect_issues: bool,
425 pub enable_tracing: bool,
427 pub max_trace_entries: usize,
429 pub enable_watchpoints: bool,
431 pub break_on_error: bool,
433 pub break_on_warning: bool,
435 pub max_issues: usize,
437}
438
439impl Default for TensorDebuggerConfig {
440 fn default() -> Self {
441 Self {
442 auto_detect_issues: true,
443 enable_tracing: true,
444 max_trace_entries: 1000,
445 enable_watchpoints: true,
446 break_on_error: true,
447 break_on_warning: false,
448 max_issues: 100,
449 }
450 }
451}
452
453pub struct TensorDebugger {
455 config: TensorDebuggerConfig,
456 tensors: Arc<Mutex<HashMap<String, Tensor>>>,
458 issues: Arc<Mutex<VecDeque<TensorDebugIssue>>>,
460 traces: Arc<Mutex<VecDeque<OperationTrace>>>,
462 watchpoints: Arc<Mutex<Vec<Watchpoint>>>,
464 breakpoint_hit: Arc<Mutex<bool>>,
466 stats_cache: Arc<Mutex<HashMap<String, DebugTensorStats>>>,
468}
469
470impl TensorDebugger {
471 pub fn new() -> Self {
473 Self::with_config(TensorDebuggerConfig::default())
474 }
475
476 pub fn with_config(config: TensorDebuggerConfig) -> Self {
478 Self {
479 config,
480 tensors: Arc::new(Mutex::new(HashMap::new())),
481 issues: Arc::new(Mutex::new(VecDeque::new())),
482 traces: Arc::new(Mutex::new(VecDeque::new())),
483 watchpoints: Arc::new(Mutex::new(Vec::new())),
484 breakpoint_hit: Arc::new(Mutex::new(false)),
485 stats_cache: Arc::new(Mutex::new(HashMap::new())),
486 }
487 }
488
489 pub fn register_tensor(&self, name: String, tensor: Tensor) -> Result<()> {
491 let mut tensors = self.tensors.lock().expect("Lock poisoned");
492 tensors.insert(name.clone(), tensor.clone());
493
494 let stats = DebugTensorStats::from_tensor(&tensor)?;
496
497 {
498 let mut cache = self.stats_cache.lock().expect("Lock poisoned");
499 cache.insert(name.clone(), stats.clone());
500 }
501
502 if self.config.auto_detect_issues {
504 let detected_issues = stats.detect_issues();
505 if !detected_issues.is_empty() {
506 let mut issues = self.issues.lock().expect("Lock poisoned");
507 for mut issue in detected_issues {
508 issue = issue.with_tensor_name(name.clone());
509
510 if (issue.severity == Severity::Error && self.config.break_on_error)
512 || (issue.severity == Severity::Warning && self.config.break_on_warning)
513 {
514 *self.breakpoint_hit.lock().expect("Lock poisoned") = true;
515 }
516
517 issues.push_back(issue);
518
519 while issues.len() > self.config.max_issues {
521 issues.pop_front();
522 }
523 }
524 }
525 }
526
527 if self.config.enable_watchpoints {
529 self.check_watchpoints(&name, &tensor)?;
530 }
531
532 Ok(())
533 }
534
535 pub fn add_watchpoint(&self, watchpoint: Watchpoint) {
537 let mut watchpoints = self.watchpoints.lock().expect("Lock poisoned");
538 watchpoints.push(watchpoint);
539 }
540
541 pub fn remove_watchpoint(&self, pattern: &str) {
543 let mut watchpoints = self.watchpoints.lock().expect("Lock poisoned");
544 watchpoints.retain(|w| w.tensor_pattern != pattern);
545 }
546
547 fn check_watchpoints(&self, name: &str, tensor: &Tensor) -> Result<()> {
549 let mut watchpoints = self.watchpoints.lock().expect("Lock poisoned");
550
551 for wp in watchpoints.iter_mut() {
552 if name == wp.tensor_pattern || wp.tensor_pattern == "*" {
554 let triggered = match &wp.condition {
555 WatchCondition::HasNaN => {
556 let stats = DebugTensorStats::from_tensor(tensor)?;
557 stats.nan_count > 0
558 },
559 WatchCondition::HasInf => {
560 let stats = DebugTensorStats::from_tensor(tensor)?;
561 stats.inf_count > 0
562 },
563 WatchCondition::ValueExceeds(threshold) => {
564 let stats = DebugTensorStats::from_tensor(tensor)?;
565 stats.max.is_some_and(|max| max.abs() > *threshold)
566 },
567 WatchCondition::ValueBelow(threshold) => {
568 let stats = DebugTensorStats::from_tensor(tensor)?;
569 stats.min.is_some_and(|min| min.abs() < *threshold)
570 },
571 WatchCondition::ShapeEquals(expected_shape) => {
572 tensor.shape() == expected_shape.as_slice()
573 },
574 WatchCondition::Custom(_) => false, };
576
577 if triggered {
578 wp.trigger_count += 1;
579
580 if wp.break_on_trigger {
581 *self.breakpoint_hit.lock().expect("Lock poisoned") = true;
582 }
583
584 let issue = TensorDebugIssue::new(
586 TensorIssueType::OperationFailure,
587 Severity::Warning,
588 format!("Watchpoint triggered: {:?}", wp.condition),
589 )
590 .with_tensor_name(name.to_string())
591 .with_metadata("trigger_count".to_string(), wp.trigger_count.to_string());
592
593 let mut issues = self.issues.lock().expect("Lock poisoned");
594 issues.push_back(issue);
595 }
596 }
597 }
598
599 Ok(())
600 }
601
602 pub fn get_tensor(&self, name: &str) -> Option<Tensor> {
604 let tensors = self.tensors.lock().expect("Lock poisoned");
605 tensors.get(name).cloned()
606 }
607
608 pub fn get_stats(&self, name: &str) -> Option<DebugTensorStats> {
610 let cache = self.stats_cache.lock().expect("Lock poisoned");
611 cache.get(name).cloned()
612 }
613
614 pub fn get_issues(&self) -> Vec<TensorDebugIssue> {
616 let issues = self.issues.lock().expect("Lock poisoned");
617 issues.iter().cloned().collect()
618 }
619
620 pub fn clear_issues(&self) {
622 let mut issues = self.issues.lock().expect("Lock poisoned");
623 issues.clear();
624 }
625
626 pub fn get_traces(&self) -> Vec<OperationTrace> {
628 let traces = self.traces.lock().expect("Lock poisoned");
629 traces.iter().cloned().collect()
630 }
631
632 pub fn clear_traces(&self) {
634 let mut traces = self.traces.lock().expect("Lock poisoned");
635 traces.clear();
636 }
637
638 pub fn is_breakpoint_hit(&self) -> bool {
640 *self.breakpoint_hit.lock().expect("Lock poisoned")
641 }
642
643 pub fn clear_breakpoint(&self) {
645 *self.breakpoint_hit.lock().expect("Lock poisoned") = false;
646 }
647
648 pub fn print_summary(&self) {
650 println!("\n=== Tensor Debugger Summary ===\n");
651
652 let cache = self.stats_cache.lock().expect("Lock poisoned");
653 println!("Tracked Tensors: {}", cache.len());
654
655 for (name, stats) in cache.iter() {
656 println!("\nTensor: {}", name);
657 println!(" Shape: {:?}", stats.shape);
658 println!(" DType: {:?}", stats.dtype);
659 println!(" Elements: {}", stats.total_elements);
660 println!(" Memory: {} bytes", stats.memory_bytes);
661
662 if let Some(min) = stats.min {
663 println!(" Min: {:.6}", min);
664 }
665 if let Some(max) = stats.max {
666 println!(" Max: {:.6}", max);
667 }
668 if let Some(mean) = stats.mean {
669 println!(" Mean: {:.6}", mean);
670 }
671 if let Some(std) = stats.std_dev {
672 println!(" Std Dev: {:.6}", std);
673 }
674
675 if stats.nan_count > 0 {
676 println!(" ⚠️ NaN count: {}", stats.nan_count);
677 }
678 if stats.inf_count > 0 {
679 println!(" ⚠️ Inf count: {}", stats.inf_count);
680 }
681 }
682
683 let issues = self.issues.lock().expect("Lock poisoned");
684 if !issues.is_empty() {
685 println!("\n=== Issues ({}) ===\n", issues.len());
686 for (i, issue) in issues.iter().enumerate() {
687 println!(
688 "{}. [{:?}] {}: {}",
689 i + 1,
690 issue.severity,
691 issue.issue_type,
692 issue.message
693 );
694 if let Some(tensor_name) = &issue.tensor_name {
695 println!(" Tensor: {}", tensor_name);
696 }
697 }
698 }
699
700 println!("\n==============================\n");
701 }
702}
703
704impl Default for TensorDebugger {
705 fn default() -> Self {
706 Self::new()
707 }
708}
709
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created debugger has no breakpoint set and no issues.
    #[test]
    fn test_debugger_creation() {
        let debugger = TensorDebugger::new();

        assert!(!debugger.is_breakpoint_hit());
        assert!(debugger.get_issues().is_empty());
    }

    /// Registering a tensor makes both the tensor and its stats retrievable.
    #[test]
    fn test_tensor_registration() -> Result<()> {
        let debugger = TensorDebugger::new();
        let tensor = Tensor::ones(&[2, 3])?;
        debugger.register_tensor("test_tensor".to_string(), tensor)?;

        assert!(debugger.get_tensor("test_tensor").is_some());

        let stats = debugger.get_stats("test_tensor");
        assert!(stats.is_some());
        let stats = stats.expect("operation failed in test");
        assert_eq!(stats.shape, vec![2, 3]);
        assert_eq!(stats.total_elements, 6);

        Ok(())
    }

    /// A tensor containing NaN produces a `NaN` issue on registration.
    #[test]
    fn test_nan_detection() -> Result<()> {
        let debugger = TensorDebugger::new();
        let data = vec![1.0, 2.0, f32::NAN, 4.0];
        debugger.register_tensor("nan_tensor".to_string(), Tensor::from_slice(&data, &[4])?)?;

        let issues = debugger.get_issues();
        assert!(!issues.is_empty());
        assert!(issues.iter().any(|issue| issue.issue_type == TensorIssueType::NaN));

        Ok(())
    }

    /// A `HasNaN` watchpoint with `break_on_trigger` sets the breakpoint flag.
    #[test]
    fn test_watchpoint() -> Result<()> {
        let debugger = TensorDebugger::new();
        debugger.add_watchpoint(Watchpoint {
            tensor_pattern: "watched".to_string(),
            condition: WatchCondition::HasNaN,
            break_on_trigger: true,
            trigger_count: 0,
        });

        let data = vec![1.0, f32::NAN];
        debugger.register_tensor("watched".to_string(), Tensor::from_slice(&data, &[2])?)?;

        assert!(debugger.is_breakpoint_hit());

        Ok(())
    }
}