// trustformers_core/tensor_debugger.rs

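//! Interactive Tensor Debugger for TrustformeRS.
//!
//! This module provides debugging tools for tensor operations: per-tensor
//! statistics, automatic detection of numerical issues (NaN/Inf, all-zero
//! tensors, vanishing or exploding magnitudes), watchpoints, and a breakpoint
//! flag for interactive debugging.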
5use crate::errors::Result;
6use crate::tensor::{DType, Tensor};
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, VecDeque};
9use std::fmt;
10use std::sync::{Arc, Mutex};
11use std::time::Instant;
12
13/// Severity level for debugger issues
14#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
15pub enum Severity {
16    /// Informational message
17    Info,
18    /// Warning about potential issues
19    Warning,
20    /// Error that should be addressed
21    Error,
22    /// Critical issue requiring immediate attention
23    Critical,
24}
25
26/// Type of issue detected by the tensor debugger
27#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
28pub enum TensorIssueType {
29    /// NaN values detected
30    NaN,
31    /// Infinite values detected
32    Infinity,
33    /// Gradient vanishing (very small values)
34    VanishingGradient,
35    /// Gradient exploding (very large values)
36    ExplodingGradient,
37    /// All zeros in tensor
38    AllZeros,
39    /// Unusual value distribution
40    UnusualDistribution,
41    /// Memory leak suspected
42    MemoryLeak,
43    /// Dtype mismatch
44    DTypeMismatch,
45    /// Shape mismatch
46    ShapeMismatch,
47    /// Operation failure
48    OperationFailure,
49}
50
51impl fmt::Display for TensorIssueType {
52    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
53        match self {
54            TensorIssueType::NaN => write!(f, "NaN Values"),
55            TensorIssueType::Infinity => write!(f, "Infinite Values"),
56            TensorIssueType::VanishingGradient => write!(f, "Vanishing Gradient"),
57            TensorIssueType::ExplodingGradient => write!(f, "Exploding Gradient"),
58            TensorIssueType::AllZeros => write!(f, "All Zeros"),
59            TensorIssueType::UnusualDistribution => write!(f, "Unusual Distribution"),
60            TensorIssueType::MemoryLeak => write!(f, "Memory Leak"),
61            TensorIssueType::DTypeMismatch => write!(f, "DType Mismatch"),
62            TensorIssueType::ShapeMismatch => write!(f, "Shape Mismatch"),
63            TensorIssueType::OperationFailure => write!(f, "Operation Failure"),
64        }
65    }
66}
67
68/// Issue detected by the tensor debugger
69#[derive(Debug, Clone, Serialize, Deserialize)]
70pub struct TensorDebugIssue {
71    /// Type of issue
72    pub issue_type: TensorIssueType,
73    /// Severity level
74    pub severity: Severity,
75    /// Human-readable message
76    pub message: String,
77    /// Tensor name (if available)
78    pub tensor_name: Option<String>,
79    /// Operation name (if available)
80    pub operation: Option<String>,
81    /// Location in code (file:line)
82    pub location: Option<String>,
83    /// Timestamp when issue was detected
84    pub timestamp: std::time::SystemTime,
85    /// Additional metadata
86    pub metadata: HashMap<String, String>,
87}
88
89impl TensorDebugIssue {
90    fn new(issue_type: TensorIssueType, severity: Severity, message: String) -> Self {
91        Self {
92            issue_type,
93            severity,
94            message,
95            tensor_name: None,
96            operation: None,
97            location: None,
98            timestamp: std::time::SystemTime::now(),
99            metadata: HashMap::new(),
100        }
101    }
102
103    /// Add tensor name to issue
104    pub fn with_tensor_name(mut self, name: String) -> Self {
105        self.tensor_name = Some(name);
106        self
107    }
108
109    /// Add operation name to issue
110    pub fn with_operation(mut self, op: String) -> Self {
111        self.operation = Some(op);
112        self
113    }
114
115    /// Add source location
116    pub fn with_location(mut self, location: String) -> Self {
117        self.location = Some(location);
118        self
119    }
120
121    /// Add metadata
122    pub fn with_metadata(mut self, key: String, value: String) -> Self {
123        self.metadata.insert(key, value);
124        self
125    }
126}
127
128/// Statistics for a tensor (debugger)
129#[derive(Debug, Clone, Serialize, Deserialize)]
130pub struct DebugTensorStats {
131    /// Tensor shape
132    pub shape: Vec<usize>,
133    /// Data type
134    pub dtype: DType,
135    /// Minimum value (if applicable)
136    pub min: Option<f64>,
137    /// Maximum value (if applicable)
138    pub max: Option<f64>,
139    /// Mean value (if applicable)
140    pub mean: Option<f64>,
141    /// Standard deviation (if applicable)
142    pub std_dev: Option<f64>,
143    /// Number of NaN values
144    pub nan_count: usize,
145    /// Number of infinite values
146    pub inf_count: usize,
147    /// Number of zero values
148    pub zero_count: usize,
149    /// Total number of elements
150    pub total_elements: usize,
151    /// Memory usage in bytes
152    pub memory_bytes: usize,
153}
154
155impl DebugTensorStats {
156    /// Compute statistics from a tensor
157    pub fn from_tensor(tensor: &Tensor) -> Result<Self> {
158        let shape = tensor.shape().to_vec();
159        let dtype = tensor.dtype();
160        let total_elements = shape.iter().product();
161
162        // For simplicity, only compute full stats for F32 tensors
163        let (min, max, mean, std_dev, nan_count, inf_count, zero_count) = match tensor {
164            Tensor::F32(arr) => {
165                let data: Vec<f32> = arr.iter().copied().collect();
166                let mut min_val = f64::INFINITY;
167                let mut max_val = f64::NEG_INFINITY;
168                let mut sum = 0.0;
169                let mut nan_count = 0;
170                let mut inf_count = 0;
171                let mut zero_count = 0;
172
173                for &val in &data {
174                    if val.is_nan() {
175                        nan_count += 1;
176                        continue;
177                    }
178                    if val.is_infinite() {
179                        inf_count += 1;
180                        continue;
181                    }
182                    if val == 0.0 {
183                        zero_count += 1;
184                    }
185
186                    let val_f64 = val as f64;
187                    min_val = min_val.min(val_f64);
188                    max_val = max_val.max(val_f64);
189                    sum += val_f64;
190                }
191
192                let count = (data.len() - nan_count - inf_count) as f64;
193                let mean = if count > 0.0 { sum / count } else { 0.0 };
194
195                // Compute std dev
196                let mut sum_sq_diff = 0.0;
197                for &val in &data {
198                    if !val.is_nan() && !val.is_infinite() {
199                        let diff = val as f64 - mean;
200                        sum_sq_diff += diff * diff;
201                    }
202                }
203                let std_dev = if count > 0.0 { (sum_sq_diff / count).sqrt() } else { 0.0 };
204
205                (
206                    Some(min_val),
207                    Some(max_val),
208                    Some(mean),
209                    Some(std_dev),
210                    nan_count,
211                    inf_count,
212                    zero_count,
213                )
214            },
215            Tensor::F64(arr) => {
216                let data: Vec<f64> = arr.iter().copied().collect();
217                let mut min_val = f64::INFINITY;
218                let mut max_val = f64::NEG_INFINITY;
219                let mut sum = 0.0;
220                let mut nan_count = 0;
221                let mut inf_count = 0;
222                let mut zero_count = 0;
223
224                for &val in &data {
225                    if val.is_nan() {
226                        nan_count += 1;
227                        continue;
228                    }
229                    if val.is_infinite() {
230                        inf_count += 1;
231                        continue;
232                    }
233                    if val == 0.0 {
234                        zero_count += 1;
235                    }
236
237                    min_val = min_val.min(val);
238                    max_val = max_val.max(val);
239                    sum += val;
240                }
241
242                let count = (data.len() - nan_count - inf_count) as f64;
243                let mean = if count > 0.0 { sum / count } else { 0.0 };
244
245                // Compute std dev
246                let mut sum_sq_diff = 0.0;
247                for &val in &data {
248                    if !val.is_nan() && !val.is_infinite() {
249                        let diff = val - mean;
250                        sum_sq_diff += diff * diff;
251                    }
252                }
253                let std_dev = if count > 0.0 { (sum_sq_diff / count).sqrt() } else { 0.0 };
254
255                (
256                    Some(min_val),
257                    Some(max_val),
258                    Some(mean),
259                    Some(std_dev),
260                    nan_count,
261                    inf_count,
262                    zero_count,
263                )
264            },
265            _ => (None, None, None, None, 0, 0, 0),
266        };
267
268        let memory_bytes = total_elements * dtype.size_in_bytes();
269
270        Ok(Self {
271            shape,
272            dtype,
273            min,
274            max,
275            mean,
276            std_dev,
277            nan_count,
278            inf_count,
279            zero_count,
280            total_elements,
281            memory_bytes,
282        })
283    }
284
285    /// Check for potential issues
286    pub fn detect_issues(&self) -> Vec<TensorDebugIssue> {
287        let mut issues = Vec::new();
288
289        // Check for NaN values
290        if self.nan_count > 0 {
291            issues.push(
292                TensorDebugIssue::new(
293                    TensorIssueType::NaN,
294                    Severity::Error,
295                    format!(
296                        "Found {} NaN values out of {}",
297                        self.nan_count, self.total_elements
298                    ),
299                )
300                .with_metadata("nan_count".to_string(), self.nan_count.to_string())
301                .with_metadata(
302                    "nan_percentage".to_string(),
303                    format!(
304                        "{:.2}%",
305                        100.0 * self.nan_count as f64 / self.total_elements as f64
306                    ),
307                ),
308            );
309        }
310
311        // Check for infinite values
312        if self.inf_count > 0 {
313            issues.push(
314                TensorDebugIssue::new(
315                    TensorIssueType::Infinity,
316                    Severity::Error,
317                    format!(
318                        "Found {} infinite values out of {}",
319                        self.inf_count, self.total_elements
320                    ),
321                )
322                .with_metadata("inf_count".to_string(), self.inf_count.to_string()),
323            );
324        }
325
326        // Check for all zeros
327        if self.zero_count == self.total_elements {
328            issues.push(TensorDebugIssue::new(
329                TensorIssueType::AllZeros,
330                Severity::Warning,
331                "Tensor contains all zeros".to_string(),
332            ));
333        }
334
335        // Check for vanishing values (very small)
336        if let (Some(max_val), Some(mean_val)) = (self.max, self.mean) {
337            if max_val.abs() < 1e-7 && mean_val.abs() < 1e-7 {
338                issues.push(
339                    TensorDebugIssue::new(
340                        TensorIssueType::VanishingGradient,
341                        Severity::Warning,
342                        format!(
343                            "Very small values detected (max: {:.2e}, mean: {:.2e})",
344                            max_val, mean_val
345                        ),
346                    )
347                    .with_metadata("max_value".to_string(), format!("{:.2e}", max_val))
348                    .with_metadata("mean_value".to_string(), format!("{:.2e}", mean_val)),
349                );
350            }
351        }
352
353        // Check for exploding values (very large)
354        if let Some(max_val) = self.max {
355            if max_val.abs() > 1e6 {
356                issues.push(
357                    TensorDebugIssue::new(
358                        TensorIssueType::ExplodingGradient,
359                        Severity::Error,
360                        format!("Very large values detected (max: {:.2e})", max_val),
361                    )
362                    .with_metadata("max_value".to_string(), format!("{:.2e}", max_val)),
363                );
364            }
365        }
366
367        issues
368    }
369}
370
/// One recorded tensor operation.
///
/// Holds an [`Instant`], which serde cannot serialize, so unlike the other
/// record types in this module it only derives `Debug` and `Clone`.
#[derive(Debug, Clone)]
pub struct OperationTrace {
    /// Name of the operation.
    pub operation: String,
    /// Names of the input tensors.
    pub inputs: Vec<String>,
    /// Name of the output tensor.
    pub output: String,
    /// Monotonic timestamp of the operation (not serializable).
    pub timestamp: Instant,
    /// How long the operation took.
    pub duration: std::time::Duration,
    /// Input shapes (presumably parallel to `inputs`; confirm with producers).
    pub input_shapes: Vec<Vec<usize>>,
    /// Shape of the output tensor.
    pub output_shape: Vec<usize>,
}
389
390/// Watchpoint condition
391#[derive(Debug, Clone, Serialize, Deserialize)]
392pub enum WatchCondition {
393    /// Watch for NaN values
394    HasNaN,
395    /// Watch for infinite values
396    HasInf,
397    /// Watch for values exceeding threshold
398    ValueExceeds(f64),
399    /// Watch for values below threshold
400    ValueBelow(f64),
401    /// Watch for specific shape
402    ShapeEquals(Vec<usize>),
403    /// Custom condition (string description)
404    Custom(String),
405}
406
407/// Watchpoint on a tensor
408#[derive(Debug, Clone, Serialize, Deserialize)]
409pub struct Watchpoint {
410    /// Tensor name pattern (supports wildcards)
411    pub tensor_pattern: String,
412    /// Condition to watch for
413    pub condition: WatchCondition,
414    /// Whether to break on condition
415    pub break_on_trigger: bool,
416    /// Number of times triggered
417    pub trigger_count: usize,
418}
419
420/// Tensor debugger configuration
421#[derive(Debug, Clone, Serialize, Deserialize)]
422pub struct TensorDebuggerConfig {
423    /// Enable automatic issue detection
424    pub auto_detect_issues: bool,
425    /// Enable operation tracing
426    pub enable_tracing: bool,
427    /// Maximum number of trace entries to keep
428    pub max_trace_entries: usize,
429    /// Enable watchpoints
430    pub enable_watchpoints: bool,
431    /// Break on errors
432    pub break_on_error: bool,
433    /// Break on warnings
434    pub break_on_warning: bool,
435    /// Maximum number of issues to track
436    pub max_issues: usize,
437}
438
439impl Default for TensorDebuggerConfig {
440    fn default() -> Self {
441        Self {
442            auto_detect_issues: true,
443            enable_tracing: true,
444            max_trace_entries: 1000,
445            enable_watchpoints: true,
446            break_on_error: true,
447            break_on_warning: false,
448            max_issues: 100,
449        }
450    }
451}
452
453/// Interactive tensor debugger
454pub struct TensorDebugger {
455    config: TensorDebuggerConfig,
456    /// Named tensors being tracked
457    tensors: Arc<Mutex<HashMap<String, Tensor>>>,
458    /// Detected issues
459    issues: Arc<Mutex<VecDeque<TensorDebugIssue>>>,
460    /// Operation traces
461    traces: Arc<Mutex<VecDeque<OperationTrace>>>,
462    /// Active watchpoints
463    watchpoints: Arc<Mutex<Vec<Watchpoint>>>,
464    /// Breakpoint flag
465    breakpoint_hit: Arc<Mutex<bool>>,
466    /// Statistics cache
467    stats_cache: Arc<Mutex<HashMap<String, DebugTensorStats>>>,
468}
469
470impl TensorDebugger {
471    /// Create a new tensor debugger with default configuration
472    pub fn new() -> Self {
473        Self::with_config(TensorDebuggerConfig::default())
474    }
475
476    /// Create a new tensor debugger with custom configuration
477    pub fn with_config(config: TensorDebuggerConfig) -> Self {
478        Self {
479            config,
480            tensors: Arc::new(Mutex::new(HashMap::new())),
481            issues: Arc::new(Mutex::new(VecDeque::new())),
482            traces: Arc::new(Mutex::new(VecDeque::new())),
483            watchpoints: Arc::new(Mutex::new(Vec::new())),
484            breakpoint_hit: Arc::new(Mutex::new(false)),
485            stats_cache: Arc::new(Mutex::new(HashMap::new())),
486        }
487    }
488
489    /// Register a tensor for debugging
490    pub fn register_tensor(&self, name: String, tensor: Tensor) -> Result<()> {
491        let mut tensors = self.tensors.lock().expect("Lock poisoned");
492        tensors.insert(name.clone(), tensor.clone());
493
494        // Compute and cache statistics
495        let stats = DebugTensorStats::from_tensor(&tensor)?;
496
497        {
498            let mut cache = self.stats_cache.lock().expect("Lock poisoned");
499            cache.insert(name.clone(), stats.clone());
500        }
501
502        // Auto-detect issues if enabled
503        if self.config.auto_detect_issues {
504            let detected_issues = stats.detect_issues();
505            if !detected_issues.is_empty() {
506                let mut issues = self.issues.lock().expect("Lock poisoned");
507                for mut issue in detected_issues {
508                    issue = issue.with_tensor_name(name.clone());
509
510                    // Check if we should break
511                    if (issue.severity == Severity::Error && self.config.break_on_error)
512                        || (issue.severity == Severity::Warning && self.config.break_on_warning)
513                    {
514                        *self.breakpoint_hit.lock().expect("Lock poisoned") = true;
515                    }
516
517                    issues.push_back(issue);
518
519                    // Limit issue queue
520                    while issues.len() > self.config.max_issues {
521                        issues.pop_front();
522                    }
523                }
524            }
525        }
526
527        // Check watchpoints if enabled
528        if self.config.enable_watchpoints {
529            self.check_watchpoints(&name, &tensor)?;
530        }
531
532        Ok(())
533    }
534
535    /// Add a watchpoint
536    pub fn add_watchpoint(&self, watchpoint: Watchpoint) {
537        let mut watchpoints = self.watchpoints.lock().expect("Lock poisoned");
538        watchpoints.push(watchpoint);
539    }
540
541    /// Remove all watchpoints matching pattern
542    pub fn remove_watchpoint(&self, pattern: &str) {
543        let mut watchpoints = self.watchpoints.lock().expect("Lock poisoned");
544        watchpoints.retain(|w| w.tensor_pattern != pattern);
545    }
546
547    /// Check watchpoints for a tensor
548    fn check_watchpoints(&self, name: &str, tensor: &Tensor) -> Result<()> {
549        let mut watchpoints = self.watchpoints.lock().expect("Lock poisoned");
550
551        for wp in watchpoints.iter_mut() {
552            // Simple pattern matching (exact match for now)
553            if name == wp.tensor_pattern || wp.tensor_pattern == "*" {
554                let triggered = match &wp.condition {
555                    WatchCondition::HasNaN => {
556                        let stats = DebugTensorStats::from_tensor(tensor)?;
557                        stats.nan_count > 0
558                    },
559                    WatchCondition::HasInf => {
560                        let stats = DebugTensorStats::from_tensor(tensor)?;
561                        stats.inf_count > 0
562                    },
563                    WatchCondition::ValueExceeds(threshold) => {
564                        let stats = DebugTensorStats::from_tensor(tensor)?;
565                        stats.max.is_some_and(|max| max.abs() > *threshold)
566                    },
567                    WatchCondition::ValueBelow(threshold) => {
568                        let stats = DebugTensorStats::from_tensor(tensor)?;
569                        stats.min.is_some_and(|min| min.abs() < *threshold)
570                    },
571                    WatchCondition::ShapeEquals(expected_shape) => {
572                        tensor.shape() == expected_shape.as_slice()
573                    },
574                    WatchCondition::Custom(_) => false, // Custom conditions not implemented
575                };
576
577                if triggered {
578                    wp.trigger_count += 1;
579
580                    if wp.break_on_trigger {
581                        *self.breakpoint_hit.lock().expect("Lock poisoned") = true;
582                    }
583
584                    // Log issue
585                    let issue = TensorDebugIssue::new(
586                        TensorIssueType::OperationFailure,
587                        Severity::Warning,
588                        format!("Watchpoint triggered: {:?}", wp.condition),
589                    )
590                    .with_tensor_name(name.to_string())
591                    .with_metadata("trigger_count".to_string(), wp.trigger_count.to_string());
592
593                    let mut issues = self.issues.lock().expect("Lock poisoned");
594                    issues.push_back(issue);
595                }
596            }
597        }
598
599        Ok(())
600    }
601
602    /// Get tensor by name
603    pub fn get_tensor(&self, name: &str) -> Option<Tensor> {
604        let tensors = self.tensors.lock().expect("Lock poisoned");
605        tensors.get(name).cloned()
606    }
607
608    /// Get statistics for a tensor
609    pub fn get_stats(&self, name: &str) -> Option<DebugTensorStats> {
610        let cache = self.stats_cache.lock().expect("Lock poisoned");
611        cache.get(name).cloned()
612    }
613
614    /// Get all issues
615    pub fn get_issues(&self) -> Vec<TensorDebugIssue> {
616        let issues = self.issues.lock().expect("Lock poisoned");
617        issues.iter().cloned().collect()
618    }
619
620    /// Clear all issues
621    pub fn clear_issues(&self) {
622        let mut issues = self.issues.lock().expect("Lock poisoned");
623        issues.clear();
624    }
625
626    /// Get operation traces
627    pub fn get_traces(&self) -> Vec<OperationTrace> {
628        let traces = self.traces.lock().expect("Lock poisoned");
629        traces.iter().cloned().collect()
630    }
631
632    /// Clear traces
633    pub fn clear_traces(&self) {
634        let mut traces = self.traces.lock().expect("Lock poisoned");
635        traces.clear();
636    }
637
638    /// Check if breakpoint was hit
639    pub fn is_breakpoint_hit(&self) -> bool {
640        *self.breakpoint_hit.lock().expect("Lock poisoned")
641    }
642
643    /// Clear breakpoint flag
644    pub fn clear_breakpoint(&self) {
645        *self.breakpoint_hit.lock().expect("Lock poisoned") = false;
646    }
647
648    /// Print summary of all tracked tensors
649    pub fn print_summary(&self) {
650        println!("\n=== Tensor Debugger Summary ===\n");
651
652        let cache = self.stats_cache.lock().expect("Lock poisoned");
653        println!("Tracked Tensors: {}", cache.len());
654
655        for (name, stats) in cache.iter() {
656            println!("\nTensor: {}", name);
657            println!("  Shape: {:?}", stats.shape);
658            println!("  DType: {:?}", stats.dtype);
659            println!("  Elements: {}", stats.total_elements);
660            println!("  Memory: {} bytes", stats.memory_bytes);
661
662            if let Some(min) = stats.min {
663                println!("  Min: {:.6}", min);
664            }
665            if let Some(max) = stats.max {
666                println!("  Max: {:.6}", max);
667            }
668            if let Some(mean) = stats.mean {
669                println!("  Mean: {:.6}", mean);
670            }
671            if let Some(std) = stats.std_dev {
672                println!("  Std Dev: {:.6}", std);
673            }
674
675            if stats.nan_count > 0 {
676                println!("  ⚠️  NaN count: {}", stats.nan_count);
677            }
678            if stats.inf_count > 0 {
679                println!("  ⚠️  Inf count: {}", stats.inf_count);
680            }
681        }
682
683        let issues = self.issues.lock().expect("Lock poisoned");
684        if !issues.is_empty() {
685            println!("\n=== Issues ({}) ===\n", issues.len());
686            for (i, issue) in issues.iter().enumerate() {
687                println!(
688                    "{}. [{:?}] {}: {}",
689                    i + 1,
690                    issue.severity,
691                    issue.issue_type,
692                    issue.message
693                );
694                if let Some(tensor_name) = &issue.tensor_name {
695                    println!("   Tensor: {}", tensor_name);
696                }
697            }
698        }
699
700        println!("\n==============================\n");
701    }
702}
703
704impl Default for TensorDebugger {
705    fn default() -> Self {
706        Self::new()
707    }
708}
709
#[cfg(test)]
mod tests {
    use super::*;

    /// Debugger with auto-detection off, so only watchpoints can record
    /// issues or set the breakpoint flag.
    fn watch_only_debugger() -> TensorDebugger {
        TensorDebugger::with_config(TensorDebuggerConfig {
            auto_detect_issues: false,
            ..Default::default()
        })
    }

    #[test]
    fn test_debugger_creation() {
        let debugger = TensorDebugger::new();
        assert!(!debugger.is_breakpoint_hit());
        assert_eq!(debugger.get_issues().len(), 0);
    }

    #[test]
    fn test_tensor_registration() -> Result<()> {
        let debugger = TensorDebugger::new();
        let tensor = Tensor::ones(&[2, 3])?;

        debugger.register_tensor("test_tensor".to_string(), tensor.clone())?;

        let retrieved = debugger.get_tensor("test_tensor");
        assert!(retrieved.is_some());

        let stats = debugger.get_stats("test_tensor");
        assert!(stats.is_some());

        let stats = stats.expect("operation failed in test");
        assert_eq!(stats.shape, vec![2, 3]);
        assert_eq!(stats.total_elements, 6);

        Ok(())
    }

    #[test]
    fn test_nan_detection() -> Result<()> {
        let debugger = TensorDebugger::new();

        // Create tensor with NaN
        let data = vec![1.0, 2.0, f32::NAN, 4.0];
        let tensor = Tensor::from_slice(&data, &[4])?;

        debugger.register_tensor("nan_tensor".to_string(), tensor)?;

        let issues = debugger.get_issues();
        assert!(!issues.is_empty());

        let has_nan_issue = issues.iter().any(|i| i.issue_type == TensorIssueType::NaN);
        assert!(has_nan_issue);

        Ok(())
    }

    #[test]
    fn test_watchpoint() -> Result<()> {
        // With auto-detection on, the NaN *issue* (an Error, and
        // break_on_error defaults to true) would set the breakpoint by itself,
        // hiding whether the watchpoint fired at all.
        let debugger = watch_only_debugger();

        debugger.add_watchpoint(Watchpoint {
            tensor_pattern: "watched".to_string(),
            condition: WatchCondition::HasNaN,
            break_on_trigger: true,
            trigger_count: 0,
        });

        let data = vec![1.0, f32::NAN];
        let tensor = Tensor::from_slice(&data, &[2])?;

        debugger.register_tensor("watched".to_string(), tensor)?;

        // Only the watchpoint can have set the breakpoint here.
        assert!(debugger.is_breakpoint_hit());
        let issues = debugger.get_issues();
        assert!(issues.iter().any(|i| {
            i.message.starts_with("Watchpoint triggered")
                && i.tensor_name.as_deref() == Some("watched")
        }));

        Ok(())
    }

    #[test]
    fn test_watchpoint_ignores_other_tensors() -> Result<()> {
        let debugger = watch_only_debugger();

        debugger.add_watchpoint(Watchpoint {
            tensor_pattern: "watched".to_string(),
            condition: WatchCondition::HasNaN,
            break_on_trigger: true,
            trigger_count: 0,
        });

        let data = vec![1.0, f32::NAN];
        let tensor = Tensor::from_slice(&data, &[2])?;
        debugger.register_tensor("other".to_string(), tensor)?;

        assert!(!debugger.is_breakpoint_hit());
        assert!(debugger.get_issues().is_empty());

        Ok(())
    }

    #[test]
    fn test_wildcard_watchpoint_and_removal() -> Result<()> {
        let debugger = watch_only_debugger();

        debugger.add_watchpoint(Watchpoint {
            tensor_pattern: "*".to_string(),
            condition: WatchCondition::HasNaN,
            break_on_trigger: false,
            trigger_count: 0,
        });

        let data = vec![f32::NAN, 2.0];
        debugger.register_tensor("a".to_string(), Tensor::from_slice(&data, &[2])?)?;
        assert_eq!(debugger.get_issues().len(), 1);
        assert!(!debugger.is_breakpoint_hit());

        debugger.remove_watchpoint("*");
        debugger.register_tensor("b".to_string(), Tensor::from_slice(&data, &[2])?)?;
        assert_eq!(debugger.get_issues().len(), 1);

        Ok(())
    }
}