Skip to main content

tensorlogic_infer/
debug.rs

1//! Debugging utilities for execution tracing and tensor inspection.
2//!
3//! This module provides comprehensive debugging tools for TensorLogic execution:
4//!
5//! - **ExecutionTracer**: Record execution flow through computation graphs
6//! - **TensorInspector**: Examine intermediate tensor values and statistics
7//! - **BreakpointManager**: Pause execution at specific nodes for inspection
8//! - **ExecutionRecorder**: Record full execution history for replay and analysis
9//!
10//! # Example
11//!
12//! ```rust
13//! use tensorlogic_infer::debug::{ExecutionTracer, TensorInspector, BreakpointManager};
14//!
15//! // Set up tracing
16//! let mut tracer = ExecutionTracer::new();
17//! tracer.enable();
18//!
19//! // Add breakpoints
20//! let mut breakpoints = BreakpointManager::new();
21//! breakpoints.add_node_breakpoint(5);
22//!
23//! // Execute with debugging
24//! // ... execution code ...
25//!
26//! // Analyze trace
27//! let trace = tracer.get_trace();
28//! for entry in trace.entries() {
29//!     println!("Node {}: {}ms", entry.node_id, entry.duration_ms());
30//! }
31//! ```
32
33use std::collections::HashMap;
34use std::fmt;
35use std::time::{Duration, Instant};
36
/// Execution trace entry recording a single operation.
#[derive(Debug, Clone)]
pub struct TraceEntry {
    /// Unique entry ID
    pub entry_id: usize,
    /// Node ID in the computation graph
    pub node_id: usize,
    /// Operation name
    pub operation: String,
    /// Start timestamp
    pub start_time: Instant,
    /// Duration of execution
    pub duration: Duration,
    /// Input tensor IDs
    pub input_ids: Vec<usize>,
    /// Output tensor IDs
    pub output_ids: Vec<usize>,
    /// Additional metadata
    pub metadata: HashMap<String, String>,
}

impl TraceEntry {
    /// Get duration in milliseconds.
    pub fn duration_ms(&self) -> f64 {
        self.duration.as_secs_f64() * 1000.0
    }

    /// Get duration in microseconds.
    pub fn duration_us(&self) -> f64 {
        self.duration.as_secs_f64() * 1_000_000.0
    }
}

/// Execution trace containing recorded operations.
///
/// Entries are kept in the order they were added, which is expected to be execution
/// order; [`ExecutionTrace::critical_path`] relies on that ordering to resolve
/// tensor dependencies.
#[derive(Debug, Clone)]
pub struct ExecutionTrace {
    /// Recorded entries, in insertion order.
    entries: Vec<TraceEntry>,
    /// Sum of all entry durations. This is not wall-clock time: operations that
    /// overlapped in time are each counted in full.
    total_duration: Duration,
    /// Optional ID of the graph this trace belongs to.
    graph_id: Option<usize>,
}

impl ExecutionTrace {
    /// Create a new empty trace.
    pub fn new() -> Self {
        Self {
            entries: Vec::new(),
            total_duration: Duration::ZERO,
            graph_id: None,
        }
    }

    /// Set the graph ID for this trace.
    pub fn with_graph_id(mut self, graph_id: usize) -> Self {
        self.graph_id = Some(graph_id);
        self
    }

    /// Get the graph ID associated with this trace, if any.
    pub fn graph_id(&self) -> Option<usize> {
        self.graph_id
    }

    /// Add a trace entry and accumulate its duration into the total.
    pub fn add_entry(&mut self, entry: TraceEntry) {
        self.total_duration += entry.duration;
        self.entries.push(entry);
    }

    /// Get all trace entries, in insertion order.
    pub fn entries(&self) -> &[TraceEntry] {
        &self.entries
    }

    /// Get the total execution duration (sum of entry durations).
    pub fn total_duration(&self) -> Duration {
        self.total_duration
    }

    /// Get the total duration in milliseconds.
    pub fn total_duration_ms(&self) -> f64 {
        self.total_duration.as_secs_f64() * 1000.0
    }

    /// Get entries for a specific node.
    pub fn entries_for_node(&self, node_id: usize) -> Vec<&TraceEntry> {
        self.entries
            .iter()
            .filter(|e| e.node_id == node_id)
            .collect()
    }

    /// Get the critical path (longest chain of dependent operations).
    ///
    /// An entry depends on an earlier entry when one of its `input_ids` was most
    /// recently listed in that earlier entry's `output_ids`. Only *earlier* producers
    /// are considered, because entries are assumed to be in execution order. As a
    /// result, in-place operations (an input ID that is also an output ID) and reused
    /// tensor IDs cannot create dependency cycles.
    ///
    /// Returns the chain with the greatest accumulated duration, ordered from first to
    /// last entry. If several chains finish at the same time, the one whose final entry
    /// appears last in the trace is chosen. Returns an empty vector for an empty trace.
    pub fn critical_path(&self) -> Vec<&TraceEntry> {
        if self.entries.is_empty() {
            return Vec::new();
        }

        let n = self.entries.len();

        // Earliest finish time of each entry:
        // EFT[i] = max(EFT[pred] for pred of i) + duration[i].
        let mut eft = vec![Duration::ZERO; n];
        // Predecessor through which entry i's EFT was obtained (None for chain starts).
        let mut predecessor_on_critical_path: Vec<Option<usize>> = vec![None; n];
        // Most recent producer of each tensor among the entries processed so far.
        let mut latest_producer: HashMap<usize, usize> = HashMap::new();

        for (idx, entry) in self.entries.iter().enumerate() {
            let mut max_pred_eft = Duration::ZERO;
            let mut critical_pred = None;

            for input_id in &entry.input_ids {
                if let Some(&pred_idx) = latest_producer.get(input_id) {
                    if eft[pred_idx] > max_pred_eft {
                        max_pred_eft = eft[pred_idx];
                        critical_pred = Some(pred_idx);
                    }
                }
            }

            eft[idx] = max_pred_eft + entry.duration;
            predecessor_on_critical_path[idx] = critical_pred;

            // Register outputs only after resolving inputs, so an in-place operation
            // depends on the previous producer of the tensor rather than on itself.
            for &output_id in &entry.output_ids {
                latest_producer.insert(output_id, idx);
            }
        }

        // End of the critical path: the entry with the maximum EFT
        // (`max_by_key` returns the last of several equal maxima).
        let critical_end_idx = eft
            .iter()
            .enumerate()
            .max_by_key(|(_, &time)| time)
            .map(|(idx, _)| idx)
            .unwrap_or(0);

        // Backtrack. Every predecessor index is strictly smaller than its successor's,
        // so this loop always terminates.
        let mut path = Vec::new();
        let mut current = Some(critical_end_idx);
        while let Some(idx) = current {
            path.push(&self.entries[idx]);
            current = predecessor_on_critical_path[idx];
        }

        path.reverse();
        path
    }

    /// Get the total critical path duration.
    pub fn critical_path_duration(&self) -> Duration {
        self.critical_path().iter().map(|e| e.duration).sum()
    }

    /// Calculate the parallelism factor (total work / critical path time).
    ///
    /// Values > 1 indicate potential for parallel execution. Returns `1.0` when the
    /// critical path has zero duration (including the empty trace).
    pub fn parallelism_factor(&self) -> f64 {
        let critical_time = self.critical_path_duration();
        if critical_time.is_zero() {
            return 1.0;
        }
        self.total_duration.as_secs_f64() / critical_time.as_secs_f64()
    }

    /// Get operations sorted by duration (slowest first); equal durations keep trace order.
    pub fn slowest_operations(&self, limit: usize) -> Vec<&TraceEntry> {
        let mut sorted: Vec<_> = self.entries.iter().collect();
        sorted.sort_by(|a, b| b.duration.cmp(&a.duration));
        sorted.into_iter().take(limit).collect()
    }

    /// Generate a summary report.
    pub fn summary(&self) -> TraceSummary {
        TraceSummary::from_trace(self)
    }
}

impl Default for ExecutionTrace {
    fn default() -> Self {
        Self::new()
    }
}

/// Summary statistics for an execution trace.
///
/// All time statistics are `0.0` for an empty trace.
#[derive(Debug, Clone)]
pub struct TraceSummary {
    /// Total number of operations
    pub total_operations: usize,
    /// Total execution time in milliseconds (sum of entry durations)
    pub total_time_ms: f64,
    /// Average operation time in milliseconds
    pub avg_time_ms: f64,
    /// Slowest operation time in milliseconds
    pub max_time_ms: f64,
    /// Fastest operation time in milliseconds
    pub min_time_ms: f64,
    /// Operation counts by type
    pub operation_counts: HashMap<String, usize>,
}

impl TraceSummary {
    /// Create a summary from a trace.
    pub fn from_trace(trace: &ExecutionTrace) -> Self {
        let entries = trace.entries();
        let total_operations = entries.len();
        let total_time_ms = trace.total_duration_ms();

        let (avg_time_ms, min_time_ms, max_time_ms) = if entries.is_empty() {
            // An empty trace has no fastest operation; report 0.0 rather than f64::MAX.
            (0.0, 0.0, 0.0)
        } else {
            let max = entries.iter().map(|e| e.duration_ms()).fold(0.0, f64::max);
            let min = entries
                .iter()
                .map(|e| e.duration_ms())
                .fold(f64::MAX, f64::min);
            (total_time_ms / total_operations as f64, min, max)
        };

        let mut operation_counts: HashMap<String, usize> = HashMap::new();
        for entry in entries {
            *operation_counts.entry(entry.operation.clone()).or_insert(0) += 1;
        }

        Self {
            total_operations,
            total_time_ms,
            avg_time_ms,
            max_time_ms,
            min_time_ms,
            operation_counts,
        }
    }
}

impl fmt::Display for TraceSummary {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(f, "Execution Trace Summary")?;
        writeln!(f, "=======================")?;
        writeln!(f, "Total operations: {}", self.total_operations)?;
        writeln!(f, "Total time: {:.2} ms", self.total_time_ms)?;
        writeln!(f, "Average time: {:.2} ms", self.avg_time_ms)?;
        writeln!(f, "Max time: {:.2} ms", self.max_time_ms)?;
        writeln!(f, "Min time: {:.2} ms", self.min_time_ms)?;
        writeln!(f, "\nOperation Counts:")?;
        let mut sorted_ops: Vec<_> = self.operation_counts.iter().collect();
        // Most frequent first; ties broken by name so the output is deterministic
        // (HashMap iteration order is not).
        sorted_ops.sort_by(|a, b| b.1.cmp(a.1).then_with(|| a.0.cmp(b.0)));
        for (op, count) in sorted_ops {
            writeln!(f, "  {}: {}", op, count)?;
        }
        Ok(())
    }
}
320
321/// Execution tracer for recording operation flow.
322pub struct ExecutionTracer {
323    enabled: bool,
324    current_trace: ExecutionTrace,
325    traces: Vec<ExecutionTrace>,
326    next_entry_id: usize,
327}
328
329impl ExecutionTracer {
330    /// Create a new execution tracer.
331    pub fn new() -> Self {
332        Self {
333            enabled: false,
334            current_trace: ExecutionTrace::new(),
335            traces: Vec::new(),
336            next_entry_id: 0,
337        }
338    }
339
340    /// Enable tracing.
341    pub fn enable(&mut self) {
342        self.enabled = true;
343    }
344
345    /// Disable tracing.
346    pub fn disable(&mut self) {
347        self.enabled = false;
348    }
349
350    /// Check if tracing is enabled.
351    pub fn is_enabled(&self) -> bool {
352        self.enabled
353    }
354
355    /// Start a new trace (finalizes current trace if any).
356    pub fn start_trace(&mut self, graph_id: Option<usize>) {
357        if !self.current_trace.entries.is_empty() {
358            self.finalize_trace();
359        }
360        self.current_trace = ExecutionTrace::new();
361        if let Some(id) = graph_id {
362            self.current_trace.graph_id = Some(id);
363        }
364    }
365
366    /// Finalize the current trace and store it.
367    pub fn finalize_trace(&mut self) {
368        if !self.current_trace.entries.is_empty() {
369            let trace = std::mem::take(&mut self.current_trace);
370            self.traces.push(trace);
371        }
372    }
373
374    /// Record the start of an operation.
375    pub fn record_operation_start(
376        &mut self,
377        _node_id: usize,
378        _operation: impl Into<String>,
379        _input_ids: Vec<usize>,
380    ) -> OperationHandle {
381        if !self.enabled {
382            return OperationHandle {
383                entry_id: None,
384                start_time: Instant::now(),
385            };
386        }
387
388        let entry_id = self.next_entry_id;
389        self.next_entry_id += 1;
390
391        OperationHandle {
392            entry_id: Some(entry_id),
393            start_time: Instant::now(),
394        }
395    }
396
397    /// Record the end of an operation.
398    pub fn record_operation_end(
399        &mut self,
400        handle: OperationHandle,
401        node_id: usize,
402        operation: impl Into<String>,
403        input_ids: Vec<usize>,
404        output_ids: Vec<usize>,
405        metadata: HashMap<String, String>,
406    ) {
407        if !self.enabled || handle.entry_id.is_none() {
408            return;
409        }
410
411        let duration = handle.start_time.elapsed();
412        let entry = TraceEntry {
413            entry_id: handle.entry_id.unwrap(),
414            node_id,
415            operation: operation.into(),
416            start_time: handle.start_time,
417            duration,
418            input_ids,
419            output_ids,
420            metadata,
421        };
422
423        self.current_trace.add_entry(entry);
424    }
425
426    /// Get the current trace.
427    pub fn get_trace(&self) -> &ExecutionTrace {
428        &self.current_trace
429    }
430
431    /// Get all recorded traces.
432    pub fn get_all_traces(&self) -> &[ExecutionTrace] {
433        &self.traces
434    }
435
436    /// Clear all traces.
437    pub fn clear(&mut self) {
438        self.current_trace = ExecutionTrace::new();
439        self.traces.clear();
440        self.next_entry_id = 0;
441    }
442}
443
444impl Default for ExecutionTracer {
445    fn default() -> Self {
446        Self::new()
447    }
448}
449
450/// Handle for an in-progress operation recording.
451pub struct OperationHandle {
452    entry_id: Option<usize>,
453    start_time: Instant,
454}
455
/// Summary statistics describing one tensor.
///
/// `tensor_id`, `shape`, `num_elements` and `dtype` are always filled in. The numeric
/// fields remain `None` until [`TensorStats::with_statistics`] supplies them.
#[derive(Debug, Clone)]
pub struct TensorStats {
    /// Tensor ID
    pub tensor_id: usize,
    /// Shape of the tensor
    pub shape: Vec<usize>,
    /// Number of elements (product of `shape`; 1 for an empty shape)
    pub num_elements: usize,
    /// Data type name
    pub dtype: String,
    /// Minimum value (if computed)
    pub min_value: Option<f64>,
    /// Maximum value (if computed)
    pub max_value: Option<f64>,
    /// Mean value (if computed)
    pub mean_value: Option<f64>,
    /// Standard deviation (if computed)
    pub std_dev: Option<f64>,
    /// Number of NaN values (if computed)
    pub num_nans: Option<usize>,
    /// Number of infinite values (if computed)
    pub num_infs: Option<usize>,
}

impl TensorStats {
    /// Build stats carrying only identity information; the element count is derived
    /// from `shape`.
    pub fn new(tensor_id: usize, shape: Vec<usize>, dtype: impl Into<String>) -> Self {
        TensorStats {
            tensor_id,
            // Evaluated before `shape` is moved into the struct below.
            num_elements: shape.iter().product(),
            shape,
            dtype: dtype.into(),
            min_value: None,
            max_value: None,
            mean_value: None,
            std_dev: None,
            num_nans: None,
            num_infs: None,
        }
    }

    /// Attach computed statistics, returning the updated value (builder style).
    pub fn with_statistics(
        self,
        min: f64,
        max: f64,
        mean: f64,
        std_dev: f64,
        num_nans: usize,
        num_infs: usize,
    ) -> Self {
        TensorStats {
            min_value: Some(min),
            max_value: Some(max),
            mean_value: Some(mean),
            std_dev: Some(std_dev),
            num_nans: Some(num_nans),
            num_infs: Some(num_infs),
            ..self
        }
    }

    /// Report whether any NaN or infinite values were counted.
    ///
    /// Counts that were never computed (`None`) are treated as zero.
    pub fn has_numerical_issues(&self) -> bool {
        let has_nans = matches!(self.num_nans, Some(count) if count > 0);
        let has_infs = matches!(self.num_infs, Some(count) if count > 0);
        has_nans || has_infs
    }
}

impl fmt::Display for TensorStats {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(f, "Tensor {} Stats:", self.tensor_id)?;
        writeln!(f, "  Shape: {:?}", self.shape)?;
        writeln!(f, "  Elements: {}", self.num_elements)?;
        writeln!(f, "  DType: {}", self.dtype)?;

        // Optional numeric statistics are only printed when present.
        if let Some(value) = self.min_value {
            writeln!(f, "  Min: {:.6}", value)?;
        }
        if let Some(value) = self.max_value {
            writeln!(f, "  Max: {:.6}", value)?;
        }
        if let Some(value) = self.mean_value {
            writeln!(f, "  Mean: {:.6}", value)?;
        }
        if let Some(value) = self.std_dev {
            writeln!(f, "  Std Dev: {:.6}", value)?;
        }

        // Warnings are emitted only for non-zero counts.
        if let Some(count) = self.num_nans.filter(|&c| c > 0) {
            writeln!(f, "  ⚠️  NaNs: {}", count)?;
        }
        if let Some(count) = self.num_infs.filter(|&c| c > 0) {
            writeln!(f, "  ⚠️  Infs: {}", count)?;
        }
        Ok(())
    }
}
555
556/// Tensor inspector for examining intermediate values.
557pub struct TensorInspector {
558    enabled: bool,
559    tensor_stats: HashMap<usize, TensorStats>,
560    watch_list: Vec<usize>,
561}
562
563impl TensorInspector {
564    /// Create a new tensor inspector.
565    pub fn new() -> Self {
566        Self {
567            enabled: false,
568            tensor_stats: HashMap::new(),
569            watch_list: Vec::new(),
570        }
571    }
572
573    /// Enable inspection.
574    pub fn enable(&mut self) {
575        self.enabled = true;
576    }
577
578    /// Disable inspection.
579    pub fn disable(&mut self) {
580        self.enabled = false;
581    }
582
583    /// Check if inspection is enabled.
584    pub fn is_enabled(&self) -> bool {
585        self.enabled
586    }
587
588    /// Add a tensor to the watch list.
589    pub fn watch(&mut self, tensor_id: usize) {
590        if !self.watch_list.contains(&tensor_id) {
591            self.watch_list.push(tensor_id);
592        }
593    }
594
595    /// Remove a tensor from the watch list.
596    pub fn unwatch(&mut self, tensor_id: usize) {
597        self.watch_list.retain(|&id| id != tensor_id);
598    }
599
600    /// Clear the watch list.
601    pub fn clear_watch_list(&mut self) {
602        self.watch_list.clear();
603    }
604
605    /// Check if a tensor should be inspected.
606    pub fn should_inspect(&self, tensor_id: usize) -> bool {
607        self.enabled && (self.watch_list.is_empty() || self.watch_list.contains(&tensor_id))
608    }
609
610    /// Record tensor statistics.
611    pub fn record_stats(&mut self, stats: TensorStats) {
612        if !self.enabled {
613            return;
614        }
615        self.tensor_stats.insert(stats.tensor_id, stats);
616    }
617
618    /// Get statistics for a tensor.
619    pub fn get_stats(&self, tensor_id: usize) -> Option<&TensorStats> {
620        self.tensor_stats.get(&tensor_id)
621    }
622
623    /// Get all recorded statistics.
624    pub fn get_all_stats(&self) -> &HashMap<usize, TensorStats> {
625        &self.tensor_stats
626    }
627
628    /// Find tensors with numerical issues.
629    pub fn find_problematic_tensors(&self) -> Vec<&TensorStats> {
630        self.tensor_stats
631            .values()
632            .filter(|stats| stats.has_numerical_issues())
633            .collect()
634    }
635
636    /// Clear all recorded statistics.
637    pub fn clear(&mut self) {
638        self.tensor_stats.clear();
639    }
640}
641
642impl Default for TensorInspector {
643    fn default() -> Self {
644        Self::new()
645    }
646}
647
/// A condition under which `BreakpointManager::should_break` reports a hit.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Breakpoint {
    /// Trigger when the node with this ID is checked.
    Node(usize),
    /// Trigger when an operation with exactly this name is checked.
    Operation(String),
    /// Trigger when the caller reports a numerical issue for the checked node.
    NumericalIssue,
    /// Trigger when the reported elapsed time, in microseconds, is strictly greater
    /// than this threshold.
    TimeThreshold(u64),
    /// Conditional breakpoint whose predicate is kept as a plain string for simplicity.
    ///
    /// The predicate is not evaluated yet; this variant never triggers.
    Conditional(String),
}

/// Details of a breakpoint that fired.
#[derive(Debug, Clone)]
pub struct BreakpointHit {
    /// The breakpoint that matched.
    pub breakpoint: Breakpoint,
    /// Node at which execution paused.
    pub node_id: usize,
    /// Elapsed time supplied by the caller at the moment of the check, in microseconds.
    pub elapsed_us: u64,
    /// Free-form additional context.
    pub context: HashMap<String, String>,
}
675
676/// Manager for execution breakpoints.
677pub struct BreakpointManager {
678    enabled: bool,
679    breakpoints: Vec<Breakpoint>,
680    hits: Vec<BreakpointHit>,
681    continue_execution: bool,
682}
683
684impl BreakpointManager {
685    /// Create a new breakpoint manager.
686    pub fn new() -> Self {
687        Self {
688            enabled: false,
689            breakpoints: Vec::new(),
690            hits: Vec::new(),
691            continue_execution: true,
692        }
693    }
694
695    /// Enable breakpoint checking.
696    pub fn enable(&mut self) {
697        self.enabled = true;
698    }
699
700    /// Disable breakpoint checking.
701    pub fn disable(&mut self) {
702        self.enabled = false;
703    }
704
705    /// Check if breakpoint checking is enabled.
706    pub fn is_enabled(&self) -> bool {
707        self.enabled
708    }
709
710    /// Add a node breakpoint.
711    pub fn add_node_breakpoint(&mut self, node_id: usize) {
712        self.breakpoints.push(Breakpoint::Node(node_id));
713    }
714
715    /// Add an operation breakpoint.
716    pub fn add_operation_breakpoint(&mut self, operation: impl Into<String>) {
717        self.breakpoints
718            .push(Breakpoint::Operation(operation.into()));
719    }
720
721    /// Add a numerical issue breakpoint.
722    pub fn add_numerical_issue_breakpoint(&mut self) {
723        self.breakpoints.push(Breakpoint::NumericalIssue);
724    }
725
726    /// Add a time threshold breakpoint.
727    pub fn add_time_threshold_breakpoint(&mut self, threshold_us: u64) {
728        self.breakpoints
729            .push(Breakpoint::TimeThreshold(threshold_us));
730    }
731
732    /// Remove a breakpoint.
733    pub fn remove_breakpoint(&mut self, breakpoint: &Breakpoint) {
734        self.breakpoints.retain(|bp| bp != breakpoint);
735    }
736
737    /// Clear all breakpoints.
738    pub fn clear_breakpoints(&mut self) {
739        self.breakpoints.clear();
740    }
741
742    /// Get all breakpoints.
743    pub fn get_breakpoints(&self) -> &[Breakpoint] {
744        &self.breakpoints
745    }
746
747    /// Check if execution should break at this point.
748    pub fn should_break(
749        &mut self,
750        node_id: usize,
751        operation: &str,
752        elapsed_us: u64,
753        has_numerical_issue: bool,
754    ) -> Option<BreakpointHit> {
755        if !self.enabled || !self.continue_execution {
756            return None;
757        }
758
759        for breakpoint in &self.breakpoints {
760            let should_break = match breakpoint {
761                Breakpoint::Node(bp_node_id) => *bp_node_id == node_id,
762                Breakpoint::Operation(bp_op) => bp_op == operation,
763                Breakpoint::NumericalIssue => has_numerical_issue,
764                Breakpoint::TimeThreshold(threshold) => elapsed_us > *threshold,
765                Breakpoint::Conditional(_) => false, // Not implemented yet
766            };
767
768            if should_break {
769                let hit = BreakpointHit {
770                    breakpoint: breakpoint.clone(),
771                    node_id,
772                    elapsed_us,
773                    context: HashMap::new(),
774                };
775                self.hits.push(hit.clone());
776                self.continue_execution = false;
777                return Some(hit);
778            }
779        }
780
781        None
782    }
783
784    /// Continue execution after a breakpoint hit.
785    pub fn continue_execution(&mut self) {
786        self.continue_execution = true;
787    }
788
789    /// Get all breakpoint hits.
790    pub fn get_hits(&self) -> &[BreakpointHit] {
791        &self.hits
792    }
793
794    /// Clear all breakpoint hits.
795    pub fn clear_hits(&mut self) {
796        self.hits.clear();
797    }
798}
799
800impl Default for BreakpointManager {
801    fn default() -> Self {
802        Self::new()
803    }
804}
805
806/// Full execution recorder for replay and analysis.
807pub struct ExecutionRecorder {
808    enabled: bool,
809    tracer: ExecutionTracer,
810    inspector: TensorInspector,
811    breakpoints: BreakpointManager,
812}
813
814impl ExecutionRecorder {
815    /// Create a new execution recorder.
816    pub fn new() -> Self {
817        Self {
818            enabled: false,
819            tracer: ExecutionTracer::new(),
820            inspector: TensorInspector::new(),
821            breakpoints: BreakpointManager::new(),
822        }
823    }
824
825    /// Enable all recording features.
826    pub fn enable(&mut self) {
827        self.enabled = true;
828        self.tracer.enable();
829        self.inspector.enable();
830        self.breakpoints.enable();
831    }
832
833    /// Disable all recording features.
834    pub fn disable(&mut self) {
835        self.enabled = false;
836        self.tracer.disable();
837        self.inspector.disable();
838        self.breakpoints.disable();
839    }
840
841    /// Get the tracer.
842    pub fn tracer(&mut self) -> &mut ExecutionTracer {
843        &mut self.tracer
844    }
845
846    /// Get the inspector.
847    pub fn inspector(&mut self) -> &mut TensorInspector {
848        &mut self.inspector
849    }
850
851    /// Get the breakpoint manager.
852    pub fn breakpoints(&mut self) -> &mut BreakpointManager {
853        &mut self.breakpoints
854    }
855
856    /// Clear all recorded data.
857    pub fn clear(&mut self) {
858        self.tracer.clear();
859        self.inspector.clear();
860        self.breakpoints.clear_hits();
861    }
862
863    /// Generate a comprehensive execution report.
864    pub fn generate_report(&self) -> ExecutionReport {
865        ExecutionReport {
866            trace_summary: self.tracer.get_trace().summary(),
867            problematic_tensors: self.inspector.find_problematic_tensors().len(),
868            breakpoint_hits: self.breakpoints.get_hits().len(),
869        }
870    }
871}
872
873impl Default for ExecutionRecorder {
874    fn default() -> Self {
875        Self::new()
876    }
877}
878
879/// Comprehensive execution report.
880#[derive(Debug, Clone)]
881pub struct ExecutionReport {
882    /// Trace summary
883    pub trace_summary: TraceSummary,
884    /// Number of tensors with numerical issues
885    pub problematic_tensors: usize,
886    /// Number of breakpoint hits
887    pub breakpoint_hits: usize,
888}
889
890impl fmt::Display for ExecutionReport {
891    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
892        writeln!(f, "{}", self.trace_summary)?;
893        writeln!(f, "\nDebug Information:")?;
894        writeln!(f, "  Problematic tensors: {}", self.problematic_tensors)?;
895        writeln!(f, "  Breakpoint hits: {}", self.breakpoint_hits)?;
896        Ok(())
897    }
898}
899
900#[cfg(test)]
901mod tests {
902    use super::*;
903
904    #[test]
905    fn test_execution_tracer() {
906        let mut tracer = ExecutionTracer::new();
907        assert!(!tracer.is_enabled());
908
909        tracer.enable();
910        assert!(tracer.is_enabled());
911
912        tracer.start_trace(Some(1));
913        let handle = tracer.record_operation_start(0, "einsum", vec![0, 1]);
914        std::thread::sleep(Duration::from_millis(10));
915        tracer.record_operation_end(handle, 0, "einsum", vec![0, 1], vec![2], HashMap::new());
916
917        let trace = tracer.get_trace();
918        assert_eq!(trace.entries().len(), 1);
919        assert!(trace.total_duration_ms() >= 10.0);
920    }
921
922    #[test]
923    fn test_trace_summary() {
924        let mut trace = ExecutionTrace::new();
925        let entry = TraceEntry {
926            entry_id: 0,
927            node_id: 0,
928            operation: "einsum".to_string(),
929            start_time: Instant::now(),
930            duration: Duration::from_millis(10),
931            input_ids: vec![0],
932            output_ids: vec![1],
933            metadata: HashMap::new(),
934        };
935        trace.add_entry(entry);
936
937        let summary = trace.summary();
938        assert_eq!(summary.total_operations, 1);
939        assert!(summary.total_time_ms >= 10.0);
940    }
941
942    #[test]
943    fn test_tensor_inspector() {
944        let mut inspector = TensorInspector::new();
945        inspector.enable();
946
947        let stats =
948            TensorStats::new(0, vec![2, 3], "f64").with_statistics(0.0, 1.0, 0.5, 0.25, 0, 0);
949
950        inspector.record_stats(stats.clone());
951        assert_eq!(inspector.get_stats(0).unwrap().tensor_id, 0);
952        assert!(!stats.has_numerical_issues());
953    }
954
955    #[test]
956    fn test_tensor_numerical_issues() {
957        let stats = TensorStats::new(0, vec![2, 3], "f64").with_statistics(
958            0.0,
959            f64::INFINITY,
960            0.5,
961            0.25,
962            1,
963            1,
964        );
965
966        assert!(stats.has_numerical_issues());
967    }
968
969    #[test]
970    fn test_breakpoint_manager() {
971        let mut manager = BreakpointManager::new();
972        manager.enable();
973        manager.add_node_breakpoint(5);
974
975        let hit = manager.should_break(5, "einsum", 1000, false);
976        assert!(hit.is_some());
977        assert_eq!(hit.unwrap().node_id, 5);
978
979        manager.continue_execution();
980        let hit2 = manager.should_break(5, "einsum", 1000, false);
981        assert!(hit2.is_some());
982    }
983
984    #[test]
985    fn test_operation_breakpoint() {
986        let mut manager = BreakpointManager::new();
987        manager.enable();
988        manager.add_operation_breakpoint("matmul");
989
990        let hit = manager.should_break(1, "matmul", 1000, false);
991        assert!(hit.is_some());
992
993        let no_hit = manager.should_break(2, "add", 1000, false);
994        assert!(no_hit.is_none());
995    }
996
997    #[test]
998    fn test_time_threshold_breakpoint() {
999        let mut manager = BreakpointManager::new();
1000        manager.enable();
1001        manager.add_time_threshold_breakpoint(5000);
1002
1003        let no_hit = manager.should_break(1, "op", 4000, false);
1004        assert!(no_hit.is_none());
1005
1006        let hit = manager.should_break(1, "op", 6000, false);
1007        assert!(hit.is_some());
1008    }
1009
1010    #[test]
1011    fn test_numerical_issue_breakpoint() {
1012        let mut manager = BreakpointManager::new();
1013        manager.enable();
1014        manager.add_numerical_issue_breakpoint();
1015
1016        let no_hit = manager.should_break(1, "op", 1000, false);
1017        assert!(no_hit.is_none());
1018
1019        let hit = manager.should_break(1, "op", 1000, true);
1020        assert!(hit.is_some());
1021    }
1022
1023    #[test]
1024    fn test_execution_recorder() {
1025        let mut recorder = ExecutionRecorder::new();
1026        recorder.enable();
1027
1028        assert!(recorder.tracer().is_enabled());
1029        assert!(recorder.inspector().is_enabled());
1030        assert!(recorder.breakpoints().is_enabled());
1031
1032        recorder.clear();
1033        let report = recorder.generate_report();
1034        assert_eq!(report.trace_summary.total_operations, 0);
1035    }
1036
1037    #[test]
1038    fn test_slowest_operations() {
1039        let mut trace = ExecutionTrace::new();
1040        for i in 0..5 {
1041            let entry = TraceEntry {
1042                entry_id: i,
1043                node_id: i,
1044                operation: format!("op{}", i),
1045                start_time: Instant::now(),
1046                duration: Duration::from_millis((i as u64 + 1) * 10),
1047                input_ids: vec![],
1048                output_ids: vec![],
1049                metadata: HashMap::new(),
1050            };
1051            trace.add_entry(entry);
1052        }
1053
1054        let slowest = trace.slowest_operations(3);
1055        assert_eq!(slowest.len(), 3);
1056        assert_eq!(slowest[0].node_id, 4); // Slowest
1057        assert_eq!(slowest[1].node_id, 3);
1058        assert_eq!(slowest[2].node_id, 2);
1059    }
1060
1061    #[test]
1062    fn test_watch_list() {
1063        let mut inspector = TensorInspector::new();
1064        inspector.enable();
1065
1066        inspector.watch(1);
1067        inspector.watch(2);
1068
1069        assert!(inspector.should_inspect(1));
1070        assert!(inspector.should_inspect(2));
1071        assert!(!inspector.should_inspect(3));
1072
1073        inspector.unwatch(1);
1074        assert!(!inspector.should_inspect(1));
1075        assert!(inspector.should_inspect(2));
1076
1077        inspector.clear_watch_list();
1078        // When watch list is empty, all tensors should be inspected
1079        assert!(inspector.should_inspect(5));
1080    }
1081
1082    #[test]
1083    fn test_trace_entries_for_node() {
1084        let mut trace = ExecutionTrace::new();
1085        for i in 0..3 {
1086            let entry = TraceEntry {
1087                entry_id: i,
1088                node_id: i % 2,
1089                operation: "op".to_string(),
1090                start_time: Instant::now(),
1091                duration: Duration::from_millis(10),
1092                input_ids: vec![],
1093                output_ids: vec![],
1094                metadata: HashMap::new(),
1095            };
1096            trace.add_entry(entry);
1097        }
1098
1099        let node_0_entries = trace.entries_for_node(0);
1100        assert_eq!(node_0_entries.len(), 2);
1101
1102        let node_1_entries = trace.entries_for_node(1);
1103        assert_eq!(node_1_entries.len(), 1);
1104    }
1105
1106    #[test]
1107    fn test_critical_path_linear_chain() {
1108        // Test linear chain: op0 -> op1 -> op2
1109        let mut trace = ExecutionTrace::new();
1110
1111        // op0: produces tensor 0
1112        trace.add_entry(TraceEntry {
1113            entry_id: 0,
1114            node_id: 0,
1115            operation: "op0".to_string(),
1116            start_time: Instant::now(),
1117            duration: Duration::from_millis(10),
1118            input_ids: vec![],
1119            output_ids: vec![0],
1120            metadata: HashMap::new(),
1121        });
1122
1123        // op1: consumes tensor 0, produces tensor 1
1124        trace.add_entry(TraceEntry {
1125            entry_id: 1,
1126            node_id: 1,
1127            operation: "op1".to_string(),
1128            start_time: Instant::now(),
1129            duration: Duration::from_millis(20),
1130            input_ids: vec![0],
1131            output_ids: vec![1],
1132            metadata: HashMap::new(),
1133        });
1134
1135        // op2: consumes tensor 1, produces tensor 2
1136        trace.add_entry(TraceEntry {
1137            entry_id: 2,
1138            node_id: 2,
1139            operation: "op2".to_string(),
1140            start_time: Instant::now(),
1141            duration: Duration::from_millis(15),
1142            input_ids: vec![1],
1143            output_ids: vec![2],
1144            metadata: HashMap::new(),
1145        });
1146
1147        let critical_path = trace.critical_path();
1148        assert_eq!(critical_path.len(), 3); // All operations are on critical path
1149        assert_eq!(critical_path[0].node_id, 0);
1150        assert_eq!(critical_path[1].node_id, 1);
1151        assert_eq!(critical_path[2].node_id, 2);
1152
1153        let cp_duration = trace.critical_path_duration();
1154        assert_eq!(cp_duration, Duration::from_millis(45)); // 10 + 20 + 15
1155    }
1156
1157    #[test]
1158    fn test_critical_path_parallel_operations() {
1159        // Test parallel operations: op0 produces two independent outputs
1160        let mut trace = ExecutionTrace::new();
1161
1162        // op0: produces tensors 0 and 1
1163        trace.add_entry(TraceEntry {
1164            entry_id: 0,
1165            node_id: 0,
1166            operation: "op0".to_string(),
1167            start_time: Instant::now(),
1168            duration: Duration::from_millis(10),
1169            input_ids: vec![],
1170            output_ids: vec![0, 1],
1171            metadata: HashMap::new(),
1172        });
1173
1174        // op1: consumes tensor 0 (fast path)
1175        trace.add_entry(TraceEntry {
1176            entry_id: 1,
1177            node_id: 1,
1178            operation: "op1".to_string(),
1179            start_time: Instant::now(),
1180            duration: Duration::from_millis(5),
1181            input_ids: vec![0],
1182            output_ids: vec![2],
1183            metadata: HashMap::new(),
1184        });
1185
1186        // op2: consumes tensor 1 (slow path - critical)
1187        trace.add_entry(TraceEntry {
1188            entry_id: 2,
1189            node_id: 2,
1190            operation: "op2".to_string(),
1191            start_time: Instant::now(),
1192            duration: Duration::from_millis(20),
1193            input_ids: vec![1],
1194            output_ids: vec![3],
1195            metadata: HashMap::new(),
1196        });
1197
1198        let critical_path = trace.critical_path();
1199        // Critical path should be op0 -> op2 (the longer path)
1200        assert_eq!(critical_path.len(), 2);
1201        assert_eq!(critical_path[0].node_id, 0);
1202        assert_eq!(critical_path[1].node_id, 2);
1203
1204        let cp_duration = trace.critical_path_duration();
1205        assert_eq!(cp_duration, Duration::from_millis(30)); // 10 + 20
1206    }
1207
1208    #[test]
1209    fn test_critical_path_diamond_pattern() {
1210        // Test diamond pattern: op0 -> (op1, op2) -> op3
1211        let mut trace = ExecutionTrace::new();
1212
1213        // op0: produces tensor 0
1214        trace.add_entry(TraceEntry {
1215            entry_id: 0,
1216            node_id: 0,
1217            operation: "op0".to_string(),
1218            start_time: Instant::now(),
1219            duration: Duration::from_millis(10),
1220            input_ids: vec![],
1221            output_ids: vec![0],
1222            metadata: HashMap::new(),
1223        });
1224
1225        // op1: consumes tensor 0, produces tensor 1 (fast path)
1226        trace.add_entry(TraceEntry {
1227            entry_id: 1,
1228            node_id: 1,
1229            operation: "op1".to_string(),
1230            start_time: Instant::now(),
1231            duration: Duration::from_millis(5),
1232            input_ids: vec![0],
1233            output_ids: vec![1],
1234            metadata: HashMap::new(),
1235        });
1236
1237        // op2: consumes tensor 0, produces tensor 2 (slow path)
1238        trace.add_entry(TraceEntry {
1239            entry_id: 2,
1240            node_id: 2,
1241            operation: "op2".to_string(),
1242            start_time: Instant::now(),
1243            duration: Duration::from_millis(25),
1244            input_ids: vec![0],
1245            output_ids: vec![2],
1246            metadata: HashMap::new(),
1247        });
1248
1249        // op3: consumes tensors 1 and 2, produces tensor 3
1250        trace.add_entry(TraceEntry {
1251            entry_id: 3,
1252            node_id: 3,
1253            operation: "op3".to_string(),
1254            start_time: Instant::now(),
1255            duration: Duration::from_millis(15),
1256            input_ids: vec![1, 2],
1257            output_ids: vec![3],
1258            metadata: HashMap::new(),
1259        });
1260
1261        let critical_path = trace.critical_path();
1262        // Critical path should be op0 -> op2 -> op3 (longest path)
1263        assert_eq!(critical_path.len(), 3);
1264        assert_eq!(critical_path[0].node_id, 0);
1265        assert_eq!(critical_path[1].node_id, 2);
1266        assert_eq!(critical_path[2].node_id, 3);
1267
1268        let cp_duration = trace.critical_path_duration();
1269        assert_eq!(cp_duration, Duration::from_millis(50)); // 10 + 25 + 15
1270    }
1271
1272    #[test]
1273    fn test_critical_path_empty() {
1274        let trace = ExecutionTrace::new();
1275        let critical_path = trace.critical_path();
1276        assert_eq!(critical_path.len(), 0);
1277        assert_eq!(trace.critical_path_duration(), Duration::ZERO);
1278    }
1279
1280    #[test]
1281    fn test_critical_path_single_operation() {
1282        let mut trace = ExecutionTrace::new();
1283        trace.add_entry(TraceEntry {
1284            entry_id: 0,
1285            node_id: 0,
1286            operation: "op0".to_string(),
1287            start_time: Instant::now(),
1288            duration: Duration::from_millis(10),
1289            input_ids: vec![],
1290            output_ids: vec![0],
1291            metadata: HashMap::new(),
1292        });
1293
1294        let critical_path = trace.critical_path();
1295        assert_eq!(critical_path.len(), 1);
1296        assert_eq!(critical_path[0].node_id, 0);
1297    }
1298
1299    #[test]
1300    fn test_parallelism_factor() {
1301        let mut trace = ExecutionTrace::new();
1302
1303        // Create a scenario with some parallelism
1304        // Total work: 10 + 20 + 30 + 40 = 100ms
1305        // Critical path: 10 + 30 + 40 = 80ms
1306        // Parallelism factor: 100 / 80 = 1.25
1307
1308        // op0: start (10ms)
1309        trace.add_entry(TraceEntry {
1310            entry_id: 0,
1311            node_id: 0,
1312            operation: "op0".to_string(),
1313            start_time: Instant::now(),
1314            duration: Duration::from_millis(10),
1315            input_ids: vec![],
1316            output_ids: vec![0],
1317            metadata: HashMap::new(),
1318        });
1319
1320        // op1: parallel to op2 (20ms, fast path)
1321        trace.add_entry(TraceEntry {
1322            entry_id: 1,
1323            node_id: 1,
1324            operation: "op1".to_string(),
1325            start_time: Instant::now(),
1326            duration: Duration::from_millis(20),
1327            input_ids: vec![0],
1328            output_ids: vec![1],
1329            metadata: HashMap::new(),
1330        });
1331
1332        // op2: parallel to op1 (30ms, slow path - on critical path)
1333        trace.add_entry(TraceEntry {
1334            entry_id: 2,
1335            node_id: 2,
1336            operation: "op2".to_string(),
1337            start_time: Instant::now(),
1338            duration: Duration::from_millis(30),
1339            input_ids: vec![0],
1340            output_ids: vec![2],
1341            metadata: HashMap::new(),
1342        });
1343
1344        // op3: join (40ms)
1345        trace.add_entry(TraceEntry {
1346            entry_id: 3,
1347            node_id: 3,
1348            operation: "op3".to_string(),
1349            start_time: Instant::now(),
1350            duration: Duration::from_millis(40),
1351            input_ids: vec![1, 2],
1352            output_ids: vec![3],
1353            metadata: HashMap::new(),
1354        });
1355
1356        let parallelism = trace.parallelism_factor();
1357        // Total duration: 100ms
1358        // Critical path: 10 + 30 + 40 = 80ms
1359        // Factor: 100 / 80 = 1.25
1360        assert!((parallelism - 1.25).abs() < 0.01);
1361    }
1362
1363    #[test]
1364    fn test_critical_path_complex_graph() {
1365        // More complex dependency graph with multiple levels
1366        let mut trace = ExecutionTrace::new();
1367
1368        // Level 0: single root
1369        trace.add_entry(TraceEntry {
1370            entry_id: 0,
1371            node_id: 0,
1372            operation: "root".to_string(),
1373            start_time: Instant::now(),
1374            duration: Duration::from_millis(5),
1375            input_ids: vec![],
1376            output_ids: vec![0],
1377            metadata: HashMap::new(),
1378        });
1379
1380        // Level 1: three parallel branches
1381        for i in 1..=3 {
1382            trace.add_entry(TraceEntry {
1383                entry_id: i,
1384                node_id: i,
1385                operation: format!("branch{}", i),
1386                start_time: Instant::now(),
1387                duration: Duration::from_millis((i as u64) * 10),
1388                input_ids: vec![0],
1389                output_ids: vec![i],
1390                metadata: HashMap::new(),
1391            });
1392        }
1393
1394        // Level 2: merge two slowest branches
1395        trace.add_entry(TraceEntry {
1396            entry_id: 4,
1397            node_id: 4,
1398            operation: "merge".to_string(),
1399            start_time: Instant::now(),
1400            duration: Duration::from_millis(15),
1401            input_ids: vec![2, 3],
1402            output_ids: vec![4],
1403            metadata: HashMap::new(),
1404        });
1405
1406        let critical_path = trace.critical_path();
1407        // Critical path should be: root -> branch3 -> merge
1408        assert!(critical_path.len() >= 3);
1409        assert_eq!(critical_path[0].operation, "root");
1410        // The last entry should be "merge" (highest finish time)
1411        assert_eq!(critical_path[critical_path.len() - 1].operation, "merge");
1412    }
1413}