Skip to main content

tensorlogic_infer/
debug.rs

//! Debugging utilities for execution tracing and tensor inspection.
//!
//! This module provides comprehensive debugging tools for TensorLogic execution:
//!
//! - **ExecutionTracer**: Record execution flow through computation graphs
//! - **TensorInspector**: Examine intermediate tensor values and statistics
//! - **BreakpointManager**: Pause execution at specific nodes for inspection
//! - **ExecutionRecorder**: Record full execution history for replay and analysis
//!
//! # Example
//!
//! ```rust
//! use tensorlogic_infer::debug::{ExecutionTracer, TensorInspector, BreakpointManager};
//!
//! // Set up tracing (tracers start disabled)
//! let mut tracer = ExecutionTracer::new();
//! tracer.enable();
//!
//! // Add breakpoints (the manager also starts disabled and must be enabled to fire)
//! let mut breakpoints = BreakpointManager::new();
//! breakpoints.enable();
//! breakpoints.add_node_breakpoint(5);
//!
//! // Execute with debugging
//! // ... execution code ...
//!
//! // Analyze trace
//! let trace = tracer.get_trace();
//! for entry in trace.entries() {
//!     println!("Node {}: {}ms", entry.node_id, entry.duration_ms());
//! }
//! ```
32
33use std::cmp::Reverse;
34use std::collections::HashMap;
35use std::fmt;
36use std::time::{Duration, Instant};
37
/// One recorded operation inside an execution trace.
#[derive(Debug, Clone)]
pub struct TraceEntry {
    /// Unique identifier of this entry
    pub entry_id: usize,
    /// Identifier of the computation-graph node that was executed
    pub node_id: usize,
    /// Name of the operation
    pub operation: String,
    /// Instant at which the operation started
    pub start_time: Instant,
    /// How long the operation took
    pub duration: Duration,
    /// IDs of the tensors consumed by the operation
    pub input_ids: Vec<usize>,
    /// IDs of the tensors produced by the operation
    pub output_ids: Vec<usize>,
    /// Free-form key/value metadata attached to the entry
    pub metadata: HashMap<String, String>,
}

impl TraceEntry {
    /// Duration of the operation expressed in milliseconds.
    pub fn duration_ms(&self) -> f64 {
        let seconds = self.duration.as_secs_f64();
        seconds * 1000.0
    }

    /// Duration of the operation expressed in microseconds.
    pub fn duration_us(&self) -> f64 {
        let seconds = self.duration.as_secs_f64();
        seconds * 1_000_000.0
    }
}
70
71/// Execution trace containing recorded operations.
72#[derive(Debug, Clone)]
73pub struct ExecutionTrace {
74    entries: Vec<TraceEntry>,
75    total_duration: Duration,
76    graph_id: Option<usize>,
77}
78
79impl ExecutionTrace {
80    /// Create a new empty trace.
81    pub fn new() -> Self {
82        Self {
83            entries: Vec::new(),
84            total_duration: Duration::ZERO,
85            graph_id: None,
86        }
87    }
88
89    /// Set the graph ID for this trace.
90    pub fn with_graph_id(mut self, graph_id: usize) -> Self {
91        self.graph_id = Some(graph_id);
92        self
93    }
94
95    /// Add a trace entry.
96    pub fn add_entry(&mut self, entry: TraceEntry) {
97        self.total_duration += entry.duration;
98        self.entries.push(entry);
99    }
100
101    /// Get all trace entries.
102    pub fn entries(&self) -> &[TraceEntry] {
103        &self.entries
104    }
105
106    /// Get the total execution duration.
107    pub fn total_duration(&self) -> Duration {
108        self.total_duration
109    }
110
111    /// Get the total duration in milliseconds.
112    pub fn total_duration_ms(&self) -> f64 {
113        self.total_duration.as_secs_f64() * 1000.0
114    }
115
116    /// Get entries for a specific node.
117    pub fn entries_for_node(&self, node_id: usize) -> Vec<&TraceEntry> {
118        self.entries
119            .iter()
120            .filter(|e| e.node_id == node_id)
121            .collect()
122    }
123
124    /// Get the critical path (longest chain of dependent operations).
125    ///
126    /// Uses dependency tracking via input/output tensor IDs to compute the critical path.
127    /// The critical path is the longest chain of operations where each depends on the previous,
128    /// determining the minimum possible execution time.
129    pub fn critical_path(&self) -> Vec<&TraceEntry> {
130        if self.entries.is_empty() {
131            return Vec::new();
132        }
133
134        let n = self.entries.len();
135
136        // Map output tensors to the entries that produce them
137        let mut tensor_producers: HashMap<usize, usize> = HashMap::new();
138        for (idx, entry) in self.entries.iter().enumerate() {
139            for &output_id in &entry.output_ids {
140                tensor_producers.insert(output_id, idx);
141            }
142        }
143
144        // Build reverse dependencies: for each entry, which entries must complete before it
145        let mut predecessors: Vec<Vec<usize>> = vec![Vec::new(); n];
146        for (idx, entry) in self.entries.iter().enumerate() {
147            for &input_id in &entry.input_ids {
148                if let Some(&producer_idx) = tensor_producers.get(&input_id) {
149                    if producer_idx < n {
150                        predecessors[idx].push(producer_idx);
151                    }
152                }
153            }
154        }
155
156        // Compute earliest start time (EST) and earliest finish time (EFT) for each node
157        // EFT[i] = max(EFT[pred] for pred in predecessors[i]) + duration[i]
158        let mut eft = vec![Duration::ZERO; n];
159        let mut predecessor_on_critical_path = vec![None; n];
160
161        // Use iterative approach with fixed-point iteration to handle all dependencies
162        let mut changed = true;
163        for _ in 0..n {
164            // Maximum n iterations needed
165            if !changed {
166                break;
167            }
168            changed = false;
169
170            for idx in 0..n {
171                let mut max_pred_eft = Duration::ZERO;
172                let mut critical_pred = None;
173
174                for &pred_idx in &predecessors[idx] {
175                    if eft[pred_idx] > max_pred_eft {
176                        max_pred_eft = eft[pred_idx];
177                        critical_pred = Some(pred_idx);
178                    }
179                }
180
181                let new_eft = max_pred_eft + self.entries[idx].duration;
182                if new_eft > eft[idx] {
183                    eft[idx] = new_eft;
184                    predecessor_on_critical_path[idx] = critical_pred;
185                    changed = true;
186                }
187            }
188        }
189
190        // Find the node with maximum EFT (end of critical path)
191        let critical_end_idx = eft
192            .iter()
193            .enumerate()
194            .max_by_key(|(_, &time)| time)
195            .map(|(idx, _)| idx)
196            .unwrap_or(0);
197
198        // Backtrack from the end to find the critical path
199        let mut critical_path_indices = Vec::new();
200        let mut current = Some(critical_end_idx);
201
202        while let Some(idx) = current {
203            critical_path_indices.push(idx);
204            current = predecessor_on_critical_path[idx];
205        }
206
207        // Reverse to get path from start to end
208        critical_path_indices.reverse();
209
210        // Convert indices to entry references
211        critical_path_indices
212            .iter()
213            .map(|&idx| &self.entries[idx])
214            .collect()
215    }
216
217    /// Get the total critical path duration.
218    pub fn critical_path_duration(&self) -> Duration {
219        self.critical_path().iter().map(|e| e.duration).sum()
220    }
221
222    /// Calculate the parallelism factor (total work / critical path time).
223    /// Values > 1 indicate potential for parallel execution.
224    pub fn parallelism_factor(&self) -> f64 {
225        let critical_time = self.critical_path_duration();
226        if critical_time.as_secs_f64() == 0.0 {
227            return 1.0;
228        }
229        self.total_duration.as_secs_f64() / critical_time.as_secs_f64()
230    }
231
232    /// Get operations sorted by duration (slowest first).
233    pub fn slowest_operations(&self, limit: usize) -> Vec<&TraceEntry> {
234        let mut sorted: Vec<_> = self.entries.iter().collect();
235        sorted.sort_by_key(|b| Reverse(b.duration));
236        sorted.into_iter().take(limit).collect()
237    }
238
239    /// Generate a summary report.
240    pub fn summary(&self) -> TraceSummary {
241        TraceSummary::from_trace(self)
242    }
243}
244
245impl Default for ExecutionTrace {
246    fn default() -> Self {
247        Self::new()
248    }
249}
250
251/// Summary statistics for an execution trace.
252#[derive(Debug, Clone)]
253pub struct TraceSummary {
254    /// Total number of operations
255    pub total_operations: usize,
256    /// Total execution time in milliseconds
257    pub total_time_ms: f64,
258    /// Average operation time in milliseconds
259    pub avg_time_ms: f64,
260    /// Slowest operation time in milliseconds
261    pub max_time_ms: f64,
262    /// Fastest operation time in milliseconds
263    pub min_time_ms: f64,
264    /// Operation counts by type
265    pub operation_counts: HashMap<String, usize>,
266}
267
268impl TraceSummary {
269    /// Create a summary from a trace.
270    pub fn from_trace(trace: &ExecutionTrace) -> Self {
271        let entries = trace.entries();
272        let total_operations = entries.len();
273
274        let total_time_ms = trace.total_duration_ms();
275        let avg_time_ms = if total_operations > 0 {
276            total_time_ms / total_operations as f64
277        } else {
278            0.0
279        };
280
281        let max_time_ms = entries.iter().map(|e| e.duration_ms()).fold(0.0, f64::max);
282        let min_time_ms = entries
283            .iter()
284            .map(|e| e.duration_ms())
285            .fold(f64::MAX, f64::min);
286
287        let mut operation_counts: HashMap<String, usize> = HashMap::new();
288        for entry in entries {
289            *operation_counts.entry(entry.operation.clone()).or_insert(0) += 1;
290        }
291
292        Self {
293            total_operations,
294            total_time_ms,
295            avg_time_ms,
296            max_time_ms,
297            min_time_ms,
298            operation_counts,
299        }
300    }
301}
302
303impl fmt::Display for TraceSummary {
304    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
305        writeln!(f, "Execution Trace Summary")?;
306        writeln!(f, "=======================")?;
307        writeln!(f, "Total operations: {}", self.total_operations)?;
308        writeln!(f, "Total time: {:.2} ms", self.total_time_ms)?;
309        writeln!(f, "Average time: {:.2} ms", self.avg_time_ms)?;
310        writeln!(f, "Max time: {:.2} ms", self.max_time_ms)?;
311        writeln!(f, "Min time: {:.2} ms", self.min_time_ms)?;
312        writeln!(f, "\nOperation Counts:")?;
313        let mut sorted_ops: Vec<_> = self.operation_counts.iter().collect();
314        sorted_ops.sort_by_key(|(_, count)| std::cmp::Reverse(**count));
315        for (op, count) in sorted_ops {
316            writeln!(f, "  {}: {}", op, count)?;
317        }
318        Ok(())
319    }
320}
321
322/// Execution tracer for recording operation flow.
323pub struct ExecutionTracer {
324    enabled: bool,
325    current_trace: ExecutionTrace,
326    traces: Vec<ExecutionTrace>,
327    next_entry_id: usize,
328}
329
330impl ExecutionTracer {
331    /// Create a new execution tracer.
332    pub fn new() -> Self {
333        Self {
334            enabled: false,
335            current_trace: ExecutionTrace::new(),
336            traces: Vec::new(),
337            next_entry_id: 0,
338        }
339    }
340
341    /// Enable tracing.
342    pub fn enable(&mut self) {
343        self.enabled = true;
344    }
345
346    /// Disable tracing.
347    pub fn disable(&mut self) {
348        self.enabled = false;
349    }
350
351    /// Check if tracing is enabled.
352    pub fn is_enabled(&self) -> bool {
353        self.enabled
354    }
355
356    /// Start a new trace (finalizes current trace if any).
357    pub fn start_trace(&mut self, graph_id: Option<usize>) {
358        if !self.current_trace.entries.is_empty() {
359            self.finalize_trace();
360        }
361        self.current_trace = ExecutionTrace::new();
362        if let Some(id) = graph_id {
363            self.current_trace.graph_id = Some(id);
364        }
365    }
366
367    /// Finalize the current trace and store it.
368    pub fn finalize_trace(&mut self) {
369        if !self.current_trace.entries.is_empty() {
370            let trace = std::mem::take(&mut self.current_trace);
371            self.traces.push(trace);
372        }
373    }
374
375    /// Record the start of an operation.
376    pub fn record_operation_start(
377        &mut self,
378        _node_id: usize,
379        _operation: impl Into<String>,
380        _input_ids: Vec<usize>,
381    ) -> OperationHandle {
382        if !self.enabled {
383            return OperationHandle {
384                entry_id: None,
385                start_time: Instant::now(),
386            };
387        }
388
389        let entry_id = self.next_entry_id;
390        self.next_entry_id += 1;
391
392        OperationHandle {
393            entry_id: Some(entry_id),
394            start_time: Instant::now(),
395        }
396    }
397
398    /// Record the end of an operation.
399    pub fn record_operation_end(
400        &mut self,
401        handle: OperationHandle,
402        node_id: usize,
403        operation: impl Into<String>,
404        input_ids: Vec<usize>,
405        output_ids: Vec<usize>,
406        metadata: HashMap<String, String>,
407    ) {
408        if !self.enabled || handle.entry_id.is_none() {
409            return;
410        }
411
412        let duration = handle.start_time.elapsed();
413        let entry = TraceEntry {
414            entry_id: handle
415                .entry_id
416                .expect("entry_id is Some after is_none guard above"),
417            node_id,
418            operation: operation.into(),
419            start_time: handle.start_time,
420            duration,
421            input_ids,
422            output_ids,
423            metadata,
424        };
425
426        self.current_trace.add_entry(entry);
427    }
428
429    /// Get the current trace.
430    pub fn get_trace(&self) -> &ExecutionTrace {
431        &self.current_trace
432    }
433
434    /// Get all recorded traces.
435    pub fn get_all_traces(&self) -> &[ExecutionTrace] {
436        &self.traces
437    }
438
439    /// Clear all traces.
440    pub fn clear(&mut self) {
441        self.current_trace = ExecutionTrace::new();
442        self.traces.clear();
443        self.next_entry_id = 0;
444    }
445}
446
447impl Default for ExecutionTracer {
448    fn default() -> Self {
449        Self::new()
450    }
451}
452
/// Handle for an in-progress operation recording.
///
/// Issued by `ExecutionTracer::record_operation_start` and consumed by
/// `ExecutionTracer::record_operation_end`.
#[derive(Debug)]
pub struct OperationHandle {
    /// Entry ID reserved for the operation; `None` when tracing was disabled at start
    entry_id: Option<usize>,
    /// Instant at which the operation started
    start_time: Instant,
}
458
/// Descriptive statistics recorded for a single tensor.
#[derive(Debug, Clone)]
pub struct TensorStats {
    /// Tensor ID
    pub tensor_id: usize,
    /// Shape of the tensor
    pub shape: Vec<usize>,
    /// Number of elements (product of the shape dimensions)
    pub num_elements: usize,
    /// Data type
    pub dtype: String,
    /// Minimum value (if computed)
    pub min_value: Option<f64>,
    /// Maximum value (if computed)
    pub max_value: Option<f64>,
    /// Mean value (if computed)
    pub mean_value: Option<f64>,
    /// Standard deviation (if computed)
    pub std_dev: Option<f64>,
    /// Number of NaN values (if computed)
    pub num_nans: Option<usize>,
    /// Number of infinite values (if computed)
    pub num_infs: Option<usize>,
}

impl TensorStats {
    /// Build stats holding only the basic description of a tensor.
    ///
    /// The element count is derived from `shape`; every computed statistic starts as `None`.
    pub fn new(tensor_id: usize, shape: Vec<usize>, dtype: impl Into<String>) -> Self {
        let num_elements: usize = shape.iter().product();
        let dtype = dtype.into();
        Self {
            tensor_id,
            shape,
            num_elements,
            dtype,
            min_value: None,
            max_value: None,
            mean_value: None,
            std_dev: None,
            num_nans: None,
            num_infs: None,
        }
    }

    /// Attach computed statistics, replacing any that were previously set.
    pub fn with_statistics(
        self,
        min: f64,
        max: f64,
        mean: f64,
        std_dev: f64,
        num_nans: usize,
        num_infs: usize,
    ) -> Self {
        Self {
            min_value: Some(min),
            max_value: Some(max),
            mean_value: Some(mean),
            std_dev: Some(std_dev),
            num_nans: Some(num_nans),
            num_infs: Some(num_infs),
            ..self
        }
    }

    /// Report whether a non-zero NaN or infinity count has been recorded.
    pub fn has_numerical_issues(&self) -> bool {
        let any_nans = self.num_nans.map_or(false, |count| count > 0);
        let any_infs = self.num_infs.map_or(false, |count| count > 0);
        any_nans || any_infs
    }
}

impl fmt::Display for TensorStats {
    /// Write a multi-line report; optional statistics are printed only when present, and
    /// NaN/Inf counts only when they are non-zero.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(f, "Tensor {} Stats:", self.tensor_id)?;
        writeln!(f, "  Shape: {:?}", self.shape)?;
        writeln!(f, "  Elements: {}", self.num_elements)?;
        writeln!(f, "  DType: {}", self.dtype)?;

        if let Some(lowest) = self.min_value {
            writeln!(f, "  Min: {:.6}", lowest)?;
        }
        if let Some(highest) = self.max_value {
            writeln!(f, "  Max: {:.6}", highest)?;
        }
        if let Some(average) = self.mean_value {
            writeln!(f, "  Mean: {:.6}", average)?;
        }
        if let Some(deviation) = self.std_dev {
            writeln!(f, "  Std Dev: {:.6}", deviation)?;
        }

        if let Some(nan_count) = self.num_nans.filter(|&count| count > 0) {
            writeln!(f, "  ⚠️  NaNs: {}", nan_count)?;
        }
        if let Some(inf_count) = self.num_infs.filter(|&count| count > 0) {
            writeln!(f, "  ⚠️  Infs: {}", inf_count)?;
        }
        Ok(())
    }
}
558
559/// Tensor inspector for examining intermediate values.
560pub struct TensorInspector {
561    enabled: bool,
562    tensor_stats: HashMap<usize, TensorStats>,
563    watch_list: Vec<usize>,
564}
565
566impl TensorInspector {
567    /// Create a new tensor inspector.
568    pub fn new() -> Self {
569        Self {
570            enabled: false,
571            tensor_stats: HashMap::new(),
572            watch_list: Vec::new(),
573        }
574    }
575
576    /// Enable inspection.
577    pub fn enable(&mut self) {
578        self.enabled = true;
579    }
580
581    /// Disable inspection.
582    pub fn disable(&mut self) {
583        self.enabled = false;
584    }
585
586    /// Check if inspection is enabled.
587    pub fn is_enabled(&self) -> bool {
588        self.enabled
589    }
590
591    /// Add a tensor to the watch list.
592    pub fn watch(&mut self, tensor_id: usize) {
593        if !self.watch_list.contains(&tensor_id) {
594            self.watch_list.push(tensor_id);
595        }
596    }
597
598    /// Remove a tensor from the watch list.
599    pub fn unwatch(&mut self, tensor_id: usize) {
600        self.watch_list.retain(|&id| id != tensor_id);
601    }
602
603    /// Clear the watch list.
604    pub fn clear_watch_list(&mut self) {
605        self.watch_list.clear();
606    }
607
608    /// Check if a tensor should be inspected.
609    pub fn should_inspect(&self, tensor_id: usize) -> bool {
610        self.enabled && (self.watch_list.is_empty() || self.watch_list.contains(&tensor_id))
611    }
612
613    /// Record tensor statistics.
614    pub fn record_stats(&mut self, stats: TensorStats) {
615        if !self.enabled {
616            return;
617        }
618        self.tensor_stats.insert(stats.tensor_id, stats);
619    }
620
621    /// Get statistics for a tensor.
622    pub fn get_stats(&self, tensor_id: usize) -> Option<&TensorStats> {
623        self.tensor_stats.get(&tensor_id)
624    }
625
626    /// Get all recorded statistics.
627    pub fn get_all_stats(&self) -> &HashMap<usize, TensorStats> {
628        &self.tensor_stats
629    }
630
631    /// Find tensors with numerical issues.
632    pub fn find_problematic_tensors(&self) -> Vec<&TensorStats> {
633        self.tensor_stats
634            .values()
635            .filter(|stats| stats.has_numerical_issues())
636            .collect()
637    }
638
639    /// Clear all recorded statistics.
640    pub fn clear(&mut self) {
641        self.tensor_stats.clear();
642    }
643}
644
645impl Default for TensorInspector {
646    fn default() -> Self {
647        Self::new()
648    }
649}
650
/// Breakpoint type for execution control.
///
/// Each variant describes one condition under which `BreakpointManager::should_break`
/// reports a hit.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Breakpoint {
    /// Break at a specific node (matches when the executing node ID is equal)
    Node(usize),
    /// Break at a specific operation type (matches when the operation name is equal)
    Operation(String),
    /// Break when a tensor has numerical issues
    NumericalIssue,
    /// Break when execution time exceeds threshold (in microseconds); the comparison is
    /// strict, so an elapsed time equal to the threshold does not trigger
    TimeThreshold(u64),
    /// Conditional breakpoint with custom predicate.
    ///
    /// Evaluation is not implemented yet: `BreakpointManager::should_break` never
    /// triggers on this variant.
    Conditional(String), // Store predicate as string for simplicity
}
665
/// Breakpoint hit information.
///
/// Produced by `BreakpointManager::should_break` when one of the registered breakpoints
/// matches; a copy of every hit is also kept by the manager.
#[derive(Debug, Clone)]
pub struct BreakpointHit {
    /// The breakpoint that was hit
    pub breakpoint: Breakpoint,
    /// Node ID where execution paused
    pub node_id: usize,
    /// Current execution time in microseconds
    pub elapsed_us: u64,
    /// Additional context (created empty by `BreakpointManager::should_break`)
    pub context: HashMap<String, String>,
}
678
679/// Manager for execution breakpoints.
680pub struct BreakpointManager {
681    enabled: bool,
682    breakpoints: Vec<Breakpoint>,
683    hits: Vec<BreakpointHit>,
684    continue_execution: bool,
685}
686
687impl BreakpointManager {
688    /// Create a new breakpoint manager.
689    pub fn new() -> Self {
690        Self {
691            enabled: false,
692            breakpoints: Vec::new(),
693            hits: Vec::new(),
694            continue_execution: true,
695        }
696    }
697
698    /// Enable breakpoint checking.
699    pub fn enable(&mut self) {
700        self.enabled = true;
701    }
702
703    /// Disable breakpoint checking.
704    pub fn disable(&mut self) {
705        self.enabled = false;
706    }
707
708    /// Check if breakpoint checking is enabled.
709    pub fn is_enabled(&self) -> bool {
710        self.enabled
711    }
712
713    /// Add a node breakpoint.
714    pub fn add_node_breakpoint(&mut self, node_id: usize) {
715        self.breakpoints.push(Breakpoint::Node(node_id));
716    }
717
718    /// Add an operation breakpoint.
719    pub fn add_operation_breakpoint(&mut self, operation: impl Into<String>) {
720        self.breakpoints
721            .push(Breakpoint::Operation(operation.into()));
722    }
723
724    /// Add a numerical issue breakpoint.
725    pub fn add_numerical_issue_breakpoint(&mut self) {
726        self.breakpoints.push(Breakpoint::NumericalIssue);
727    }
728
729    /// Add a time threshold breakpoint.
730    pub fn add_time_threshold_breakpoint(&mut self, threshold_us: u64) {
731        self.breakpoints
732            .push(Breakpoint::TimeThreshold(threshold_us));
733    }
734
735    /// Remove a breakpoint.
736    pub fn remove_breakpoint(&mut self, breakpoint: &Breakpoint) {
737        self.breakpoints.retain(|bp| bp != breakpoint);
738    }
739
740    /// Clear all breakpoints.
741    pub fn clear_breakpoints(&mut self) {
742        self.breakpoints.clear();
743    }
744
745    /// Get all breakpoints.
746    pub fn get_breakpoints(&self) -> &[Breakpoint] {
747        &self.breakpoints
748    }
749
750    /// Check if execution should break at this point.
751    pub fn should_break(
752        &mut self,
753        node_id: usize,
754        operation: &str,
755        elapsed_us: u64,
756        has_numerical_issue: bool,
757    ) -> Option<BreakpointHit> {
758        if !self.enabled || !self.continue_execution {
759            return None;
760        }
761
762        for breakpoint in &self.breakpoints {
763            let should_break = match breakpoint {
764                Breakpoint::Node(bp_node_id) => *bp_node_id == node_id,
765                Breakpoint::Operation(bp_op) => bp_op == operation,
766                Breakpoint::NumericalIssue => has_numerical_issue,
767                Breakpoint::TimeThreshold(threshold) => elapsed_us > *threshold,
768                Breakpoint::Conditional(_) => false, // Not implemented yet
769            };
770
771            if should_break {
772                let hit = BreakpointHit {
773                    breakpoint: breakpoint.clone(),
774                    node_id,
775                    elapsed_us,
776                    context: HashMap::new(),
777                };
778                self.hits.push(hit.clone());
779                self.continue_execution = false;
780                return Some(hit);
781            }
782        }
783
784        None
785    }
786
787    /// Continue execution after a breakpoint hit.
788    pub fn continue_execution(&mut self) {
789        self.continue_execution = true;
790    }
791
792    /// Get all breakpoint hits.
793    pub fn get_hits(&self) -> &[BreakpointHit] {
794        &self.hits
795    }
796
797    /// Clear all breakpoint hits.
798    pub fn clear_hits(&mut self) {
799        self.hits.clear();
800    }
801}
802
803impl Default for BreakpointManager {
804    fn default() -> Self {
805        Self::new()
806    }
807}
808
809/// Full execution recorder for replay and analysis.
810pub struct ExecutionRecorder {
811    enabled: bool,
812    tracer: ExecutionTracer,
813    inspector: TensorInspector,
814    breakpoints: BreakpointManager,
815}
816
817impl ExecutionRecorder {
818    /// Create a new execution recorder.
819    pub fn new() -> Self {
820        Self {
821            enabled: false,
822            tracer: ExecutionTracer::new(),
823            inspector: TensorInspector::new(),
824            breakpoints: BreakpointManager::new(),
825        }
826    }
827
828    /// Enable all recording features.
829    pub fn enable(&mut self) {
830        self.enabled = true;
831        self.tracer.enable();
832        self.inspector.enable();
833        self.breakpoints.enable();
834    }
835
836    /// Disable all recording features.
837    pub fn disable(&mut self) {
838        self.enabled = false;
839        self.tracer.disable();
840        self.inspector.disable();
841        self.breakpoints.disable();
842    }
843
844    /// Get the tracer.
845    pub fn tracer(&mut self) -> &mut ExecutionTracer {
846        &mut self.tracer
847    }
848
849    /// Get the inspector.
850    pub fn inspector(&mut self) -> &mut TensorInspector {
851        &mut self.inspector
852    }
853
854    /// Get the breakpoint manager.
855    pub fn breakpoints(&mut self) -> &mut BreakpointManager {
856        &mut self.breakpoints
857    }
858
859    /// Clear all recorded data.
860    pub fn clear(&mut self) {
861        self.tracer.clear();
862        self.inspector.clear();
863        self.breakpoints.clear_hits();
864    }
865
866    /// Generate a comprehensive execution report.
867    pub fn generate_report(&self) -> ExecutionReport {
868        ExecutionReport {
869            trace_summary: self.tracer.get_trace().summary(),
870            problematic_tensors: self.inspector.find_problematic_tensors().len(),
871            breakpoint_hits: self.breakpoints.get_hits().len(),
872        }
873    }
874}
875
876impl Default for ExecutionRecorder {
877    fn default() -> Self {
878        Self::new()
879    }
880}
881
882/// Comprehensive execution report.
883#[derive(Debug, Clone)]
884pub struct ExecutionReport {
885    /// Trace summary
886    pub trace_summary: TraceSummary,
887    /// Number of tensors with numerical issues
888    pub problematic_tensors: usize,
889    /// Number of breakpoint hits
890    pub breakpoint_hits: usize,
891}
892
893impl fmt::Display for ExecutionReport {
894    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
895        writeln!(f, "{}", self.trace_summary)?;
896        writeln!(f, "\nDebug Information:")?;
897        writeln!(f, "  Problematic tensors: {}", self.problematic_tensors)?;
898        writeln!(f, "  Breakpoint hits: {}", self.breakpoint_hits)?;
899        Ok(())
900    }
901}
902
903#[cfg(test)]
904mod tests {
905    use super::*;
906
907    #[test]
908    fn test_execution_tracer() {
909        let mut tracer = ExecutionTracer::new();
910        assert!(!tracer.is_enabled());
911
912        tracer.enable();
913        assert!(tracer.is_enabled());
914
915        tracer.start_trace(Some(1));
916        let handle = tracer.record_operation_start(0, "einsum", vec![0, 1]);
917        std::thread::sleep(Duration::from_millis(10));
918        tracer.record_operation_end(handle, 0, "einsum", vec![0, 1], vec![2], HashMap::new());
919
920        let trace = tracer.get_trace();
921        assert_eq!(trace.entries().len(), 1);
922        assert!(trace.total_duration_ms() >= 10.0);
923    }
924
925    #[test]
926    fn test_trace_summary() {
927        let mut trace = ExecutionTrace::new();
928        let entry = TraceEntry {
929            entry_id: 0,
930            node_id: 0,
931            operation: "einsum".to_string(),
932            start_time: Instant::now(),
933            duration: Duration::from_millis(10),
934            input_ids: vec![0],
935            output_ids: vec![1],
936            metadata: HashMap::new(),
937        };
938        trace.add_entry(entry);
939
940        let summary = trace.summary();
941        assert_eq!(summary.total_operations, 1);
942        assert!(summary.total_time_ms >= 10.0);
943    }
944
945    #[test]
946    fn test_tensor_inspector() {
947        let mut inspector = TensorInspector::new();
948        inspector.enable();
949
950        let stats =
951            TensorStats::new(0, vec![2, 3], "f64").with_statistics(0.0, 1.0, 0.5, 0.25, 0, 0);
952
953        inspector.record_stats(stats.clone());
954        assert_eq!(inspector.get_stats(0).expect("unwrap").tensor_id, 0);
955        assert!(!stats.has_numerical_issues());
956    }
957
958    #[test]
959    fn test_tensor_numerical_issues() {
960        let stats = TensorStats::new(0, vec![2, 3], "f64").with_statistics(
961            0.0,
962            f64::INFINITY,
963            0.5,
964            0.25,
965            1,
966            1,
967        );
968
969        assert!(stats.has_numerical_issues());
970    }
971
972    #[test]
973    fn test_breakpoint_manager() {
974        let mut manager = BreakpointManager::new();
975        manager.enable();
976        manager.add_node_breakpoint(5);
977
978        let hit = manager.should_break(5, "einsum", 1000, false);
979        assert!(hit.is_some());
980        assert_eq!(hit.expect("unwrap").node_id, 5);
981
982        manager.continue_execution();
983        let hit2 = manager.should_break(5, "einsum", 1000, false);
984        assert!(hit2.is_some());
985    }
986
987    #[test]
988    fn test_operation_breakpoint() {
989        let mut manager = BreakpointManager::new();
990        manager.enable();
991        manager.add_operation_breakpoint("matmul");
992
993        let hit = manager.should_break(1, "matmul", 1000, false);
994        assert!(hit.is_some());
995
996        let no_hit = manager.should_break(2, "add", 1000, false);
997        assert!(no_hit.is_none());
998    }
999
1000    #[test]
1001    fn test_time_threshold_breakpoint() {
1002        let mut manager = BreakpointManager::new();
1003        manager.enable();
1004        manager.add_time_threshold_breakpoint(5000);
1005
1006        let no_hit = manager.should_break(1, "op", 4000, false);
1007        assert!(no_hit.is_none());
1008
1009        let hit = manager.should_break(1, "op", 6000, false);
1010        assert!(hit.is_some());
1011    }
1012
1013    #[test]
1014    fn test_numerical_issue_breakpoint() {
1015        let mut manager = BreakpointManager::new();
1016        manager.enable();
1017        manager.add_numerical_issue_breakpoint();
1018
1019        let no_hit = manager.should_break(1, "op", 1000, false);
1020        assert!(no_hit.is_none());
1021
1022        let hit = manager.should_break(1, "op", 1000, true);
1023        assert!(hit.is_some());
1024    }
1025
1026    #[test]
1027    fn test_execution_recorder() {
1028        let mut recorder = ExecutionRecorder::new();
1029        recorder.enable();
1030
1031        assert!(recorder.tracer().is_enabled());
1032        assert!(recorder.inspector().is_enabled());
1033        assert!(recorder.breakpoints().is_enabled());
1034
1035        recorder.clear();
1036        let report = recorder.generate_report();
1037        assert_eq!(report.trace_summary.total_operations, 0);
1038    }
1039
1040    #[test]
1041    fn test_slowest_operations() {
1042        let mut trace = ExecutionTrace::new();
1043        for i in 0..5 {
1044            let entry = TraceEntry {
1045                entry_id: i,
1046                node_id: i,
1047                operation: format!("op{}", i),
1048                start_time: Instant::now(),
1049                duration: Duration::from_millis((i as u64 + 1) * 10),
1050                input_ids: vec![],
1051                output_ids: vec![],
1052                metadata: HashMap::new(),
1053            };
1054            trace.add_entry(entry);
1055        }
1056
1057        let slowest = trace.slowest_operations(3);
1058        assert_eq!(slowest.len(), 3);
1059        assert_eq!(slowest[0].node_id, 4); // Slowest
1060        assert_eq!(slowest[1].node_id, 3);
1061        assert_eq!(slowest[2].node_id, 2);
1062    }
1063
1064    #[test]
1065    fn test_watch_list() {
1066        let mut inspector = TensorInspector::new();
1067        inspector.enable();
1068
1069        inspector.watch(1);
1070        inspector.watch(2);
1071
1072        assert!(inspector.should_inspect(1));
1073        assert!(inspector.should_inspect(2));
1074        assert!(!inspector.should_inspect(3));
1075
1076        inspector.unwatch(1);
1077        assert!(!inspector.should_inspect(1));
1078        assert!(inspector.should_inspect(2));
1079
1080        inspector.clear_watch_list();
1081        // When watch list is empty, all tensors should be inspected
1082        assert!(inspector.should_inspect(5));
1083    }
1084
1085    #[test]
1086    fn test_trace_entries_for_node() {
1087        let mut trace = ExecutionTrace::new();
1088        for i in 0..3 {
1089            let entry = TraceEntry {
1090                entry_id: i,
1091                node_id: i % 2,
1092                operation: "op".to_string(),
1093                start_time: Instant::now(),
1094                duration: Duration::from_millis(10),
1095                input_ids: vec![],
1096                output_ids: vec![],
1097                metadata: HashMap::new(),
1098            };
1099            trace.add_entry(entry);
1100        }
1101
1102        let node_0_entries = trace.entries_for_node(0);
1103        assert_eq!(node_0_entries.len(), 2);
1104
1105        let node_1_entries = trace.entries_for_node(1);
1106        assert_eq!(node_1_entries.len(), 1);
1107    }
1108
1109    #[test]
1110    fn test_critical_path_linear_chain() {
1111        // Test linear chain: op0 -> op1 -> op2
1112        let mut trace = ExecutionTrace::new();
1113
1114        // op0: produces tensor 0
1115        trace.add_entry(TraceEntry {
1116            entry_id: 0,
1117            node_id: 0,
1118            operation: "op0".to_string(),
1119            start_time: Instant::now(),
1120            duration: Duration::from_millis(10),
1121            input_ids: vec![],
1122            output_ids: vec![0],
1123            metadata: HashMap::new(),
1124        });
1125
1126        // op1: consumes tensor 0, produces tensor 1
1127        trace.add_entry(TraceEntry {
1128            entry_id: 1,
1129            node_id: 1,
1130            operation: "op1".to_string(),
1131            start_time: Instant::now(),
1132            duration: Duration::from_millis(20),
1133            input_ids: vec![0],
1134            output_ids: vec![1],
1135            metadata: HashMap::new(),
1136        });
1137
1138        // op2: consumes tensor 1, produces tensor 2
1139        trace.add_entry(TraceEntry {
1140            entry_id: 2,
1141            node_id: 2,
1142            operation: "op2".to_string(),
1143            start_time: Instant::now(),
1144            duration: Duration::from_millis(15),
1145            input_ids: vec![1],
1146            output_ids: vec![2],
1147            metadata: HashMap::new(),
1148        });
1149
1150        let critical_path = trace.critical_path();
1151        assert_eq!(critical_path.len(), 3); // All operations are on critical path
1152        assert_eq!(critical_path[0].node_id, 0);
1153        assert_eq!(critical_path[1].node_id, 1);
1154        assert_eq!(critical_path[2].node_id, 2);
1155
1156        let cp_duration = trace.critical_path_duration();
1157        assert_eq!(cp_duration, Duration::from_millis(45)); // 10 + 20 + 15
1158    }
1159
1160    #[test]
1161    fn test_critical_path_parallel_operations() {
1162        // Test parallel operations: op0 produces two independent outputs
1163        let mut trace = ExecutionTrace::new();
1164
1165        // op0: produces tensors 0 and 1
1166        trace.add_entry(TraceEntry {
1167            entry_id: 0,
1168            node_id: 0,
1169            operation: "op0".to_string(),
1170            start_time: Instant::now(),
1171            duration: Duration::from_millis(10),
1172            input_ids: vec![],
1173            output_ids: vec![0, 1],
1174            metadata: HashMap::new(),
1175        });
1176
1177        // op1: consumes tensor 0 (fast path)
1178        trace.add_entry(TraceEntry {
1179            entry_id: 1,
1180            node_id: 1,
1181            operation: "op1".to_string(),
1182            start_time: Instant::now(),
1183            duration: Duration::from_millis(5),
1184            input_ids: vec![0],
1185            output_ids: vec![2],
1186            metadata: HashMap::new(),
1187        });
1188
1189        // op2: consumes tensor 1 (slow path - critical)
1190        trace.add_entry(TraceEntry {
1191            entry_id: 2,
1192            node_id: 2,
1193            operation: "op2".to_string(),
1194            start_time: Instant::now(),
1195            duration: Duration::from_millis(20),
1196            input_ids: vec![1],
1197            output_ids: vec![3],
1198            metadata: HashMap::new(),
1199        });
1200
1201        let critical_path = trace.critical_path();
1202        // Critical path should be op0 -> op2 (the longer path)
1203        assert_eq!(critical_path.len(), 2);
1204        assert_eq!(critical_path[0].node_id, 0);
1205        assert_eq!(critical_path[1].node_id, 2);
1206
1207        let cp_duration = trace.critical_path_duration();
1208        assert_eq!(cp_duration, Duration::from_millis(30)); // 10 + 20
1209    }
1210
1211    #[test]
1212    fn test_critical_path_diamond_pattern() {
1213        // Test diamond pattern: op0 -> (op1, op2) -> op3
1214        let mut trace = ExecutionTrace::new();
1215
1216        // op0: produces tensor 0
1217        trace.add_entry(TraceEntry {
1218            entry_id: 0,
1219            node_id: 0,
1220            operation: "op0".to_string(),
1221            start_time: Instant::now(),
1222            duration: Duration::from_millis(10),
1223            input_ids: vec![],
1224            output_ids: vec![0],
1225            metadata: HashMap::new(),
1226        });
1227
1228        // op1: consumes tensor 0, produces tensor 1 (fast path)
1229        trace.add_entry(TraceEntry {
1230            entry_id: 1,
1231            node_id: 1,
1232            operation: "op1".to_string(),
1233            start_time: Instant::now(),
1234            duration: Duration::from_millis(5),
1235            input_ids: vec![0],
1236            output_ids: vec![1],
1237            metadata: HashMap::new(),
1238        });
1239
1240        // op2: consumes tensor 0, produces tensor 2 (slow path)
1241        trace.add_entry(TraceEntry {
1242            entry_id: 2,
1243            node_id: 2,
1244            operation: "op2".to_string(),
1245            start_time: Instant::now(),
1246            duration: Duration::from_millis(25),
1247            input_ids: vec![0],
1248            output_ids: vec![2],
1249            metadata: HashMap::new(),
1250        });
1251
1252        // op3: consumes tensors 1 and 2, produces tensor 3
1253        trace.add_entry(TraceEntry {
1254            entry_id: 3,
1255            node_id: 3,
1256            operation: "op3".to_string(),
1257            start_time: Instant::now(),
1258            duration: Duration::from_millis(15),
1259            input_ids: vec![1, 2],
1260            output_ids: vec![3],
1261            metadata: HashMap::new(),
1262        });
1263
1264        let critical_path = trace.critical_path();
1265        // Critical path should be op0 -> op2 -> op3 (longest path)
1266        assert_eq!(critical_path.len(), 3);
1267        assert_eq!(critical_path[0].node_id, 0);
1268        assert_eq!(critical_path[1].node_id, 2);
1269        assert_eq!(critical_path[2].node_id, 3);
1270
1271        let cp_duration = trace.critical_path_duration();
1272        assert_eq!(cp_duration, Duration::from_millis(50)); // 10 + 25 + 15
1273    }
1274
1275    #[test]
1276    fn test_critical_path_empty() {
1277        let trace = ExecutionTrace::new();
1278        let critical_path = trace.critical_path();
1279        assert_eq!(critical_path.len(), 0);
1280        assert_eq!(trace.critical_path_duration(), Duration::ZERO);
1281    }
1282
1283    #[test]
1284    fn test_critical_path_single_operation() {
1285        let mut trace = ExecutionTrace::new();
1286        trace.add_entry(TraceEntry {
1287            entry_id: 0,
1288            node_id: 0,
1289            operation: "op0".to_string(),
1290            start_time: Instant::now(),
1291            duration: Duration::from_millis(10),
1292            input_ids: vec![],
1293            output_ids: vec![0],
1294            metadata: HashMap::new(),
1295        });
1296
1297        let critical_path = trace.critical_path();
1298        assert_eq!(critical_path.len(), 1);
1299        assert_eq!(critical_path[0].node_id, 0);
1300    }
1301
1302    #[test]
1303    fn test_parallelism_factor() {
1304        let mut trace = ExecutionTrace::new();
1305
1306        // Create a scenario with some parallelism
1307        // Total work: 10 + 20 + 30 + 40 = 100ms
1308        // Critical path: 10 + 30 + 40 = 80ms
1309        // Parallelism factor: 100 / 80 = 1.25
1310
1311        // op0: start (10ms)
1312        trace.add_entry(TraceEntry {
1313            entry_id: 0,
1314            node_id: 0,
1315            operation: "op0".to_string(),
1316            start_time: Instant::now(),
1317            duration: Duration::from_millis(10),
1318            input_ids: vec![],
1319            output_ids: vec![0],
1320            metadata: HashMap::new(),
1321        });
1322
1323        // op1: parallel to op2 (20ms, fast path)
1324        trace.add_entry(TraceEntry {
1325            entry_id: 1,
1326            node_id: 1,
1327            operation: "op1".to_string(),
1328            start_time: Instant::now(),
1329            duration: Duration::from_millis(20),
1330            input_ids: vec![0],
1331            output_ids: vec![1],
1332            metadata: HashMap::new(),
1333        });
1334
1335        // op2: parallel to op1 (30ms, slow path - on critical path)
1336        trace.add_entry(TraceEntry {
1337            entry_id: 2,
1338            node_id: 2,
1339            operation: "op2".to_string(),
1340            start_time: Instant::now(),
1341            duration: Duration::from_millis(30),
1342            input_ids: vec![0],
1343            output_ids: vec![2],
1344            metadata: HashMap::new(),
1345        });
1346
1347        // op3: join (40ms)
1348        trace.add_entry(TraceEntry {
1349            entry_id: 3,
1350            node_id: 3,
1351            operation: "op3".to_string(),
1352            start_time: Instant::now(),
1353            duration: Duration::from_millis(40),
1354            input_ids: vec![1, 2],
1355            output_ids: vec![3],
1356            metadata: HashMap::new(),
1357        });
1358
1359        let parallelism = trace.parallelism_factor();
1360        // Total duration: 100ms
1361        // Critical path: 10 + 30 + 40 = 80ms
1362        // Factor: 100 / 80 = 1.25
1363        assert!((parallelism - 1.25).abs() < 0.01);
1364    }
1365
1366    #[test]
1367    fn test_critical_path_complex_graph() {
1368        // More complex dependency graph with multiple levels
1369        let mut trace = ExecutionTrace::new();
1370
1371        // Level 0: single root
1372        trace.add_entry(TraceEntry {
1373            entry_id: 0,
1374            node_id: 0,
1375            operation: "root".to_string(),
1376            start_time: Instant::now(),
1377            duration: Duration::from_millis(5),
1378            input_ids: vec![],
1379            output_ids: vec![0],
1380            metadata: HashMap::new(),
1381        });
1382
1383        // Level 1: three parallel branches
1384        for i in 1..=3 {
1385            trace.add_entry(TraceEntry {
1386                entry_id: i,
1387                node_id: i,
1388                operation: format!("branch{}", i),
1389                start_time: Instant::now(),
1390                duration: Duration::from_millis((i as u64) * 10),
1391                input_ids: vec![0],
1392                output_ids: vec![i],
1393                metadata: HashMap::new(),
1394            });
1395        }
1396
1397        // Level 2: merge two slowest branches
1398        trace.add_entry(TraceEntry {
1399            entry_id: 4,
1400            node_id: 4,
1401            operation: "merge".to_string(),
1402            start_time: Instant::now(),
1403            duration: Duration::from_millis(15),
1404            input_ids: vec![2, 3],
1405            output_ids: vec![4],
1406            metadata: HashMap::new(),
1407        });
1408
1409        let critical_path = trace.critical_path();
1410        // Critical path should be: root -> branch3 -> merge
1411        assert!(critical_path.len() >= 3);
1412        assert_eq!(critical_path[0].operation, "root");
1413        // The last entry should be "merge" (highest finish time)
1414        assert_eq!(critical_path[critical_path.len() - 1].operation, "merge");
1415    }
1416}