1use std::collections::HashMap;
34use std::fmt;
35use std::time::{Duration, Instant};
36
/// One timed operation captured while a graph was being executed.
#[derive(Debug, Clone)]
pub struct TraceEntry {
    /// Identifier of this entry.
    pub entry_id: usize,
    /// Identifier of the graph node the operation belongs to.
    pub node_id: usize,
    /// Name of the operation that ran.
    pub operation: String,
    /// Instant at which the operation started.
    pub start_time: Instant,
    /// How long the operation took.
    pub duration: Duration,
    /// Ids of the values the operation consumed.
    pub input_ids: Vec<usize>,
    /// Ids of the values the operation produced.
    pub output_ids: Vec<usize>,
    /// Free-form key/value annotations attached to the entry.
    pub metadata: HashMap<String, String>,
}

impl TraceEntry {
    /// Expresses `duration` in a unit of which there are `units_per_second`
    /// in one second.
    fn duration_in(&self, units_per_second: f64) -> f64 {
        self.duration.as_secs_f64() * units_per_second
    }

    /// Duration of the operation in (fractional) milliseconds.
    pub fn duration_ms(&self) -> f64 {
        self.duration_in(1000.0)
    }

    /// Duration of the operation in (fractional) microseconds.
    pub fn duration_us(&self) -> f64 {
        self.duration_in(1_000_000.0)
    }
}
69
70#[derive(Debug, Clone)]
72pub struct ExecutionTrace {
73 entries: Vec<TraceEntry>,
74 total_duration: Duration,
75 graph_id: Option<usize>,
76}
77
78impl ExecutionTrace {
79 pub fn new() -> Self {
81 Self {
82 entries: Vec::new(),
83 total_duration: Duration::ZERO,
84 graph_id: None,
85 }
86 }
87
88 pub fn with_graph_id(mut self, graph_id: usize) -> Self {
90 self.graph_id = Some(graph_id);
91 self
92 }
93
94 pub fn add_entry(&mut self, entry: TraceEntry) {
96 self.total_duration += entry.duration;
97 self.entries.push(entry);
98 }
99
100 pub fn entries(&self) -> &[TraceEntry] {
102 &self.entries
103 }
104
105 pub fn total_duration(&self) -> Duration {
107 self.total_duration
108 }
109
110 pub fn total_duration_ms(&self) -> f64 {
112 self.total_duration.as_secs_f64() * 1000.0
113 }
114
115 pub fn entries_for_node(&self, node_id: usize) -> Vec<&TraceEntry> {
117 self.entries
118 .iter()
119 .filter(|e| e.node_id == node_id)
120 .collect()
121 }
122
123 pub fn critical_path(&self) -> Vec<&TraceEntry> {
129 if self.entries.is_empty() {
130 return Vec::new();
131 }
132
133 let n = self.entries.len();
134
135 let mut tensor_producers: HashMap<usize, usize> = HashMap::new();
137 for (idx, entry) in self.entries.iter().enumerate() {
138 for &output_id in &entry.output_ids {
139 tensor_producers.insert(output_id, idx);
140 }
141 }
142
143 let mut predecessors: Vec<Vec<usize>> = vec![Vec::new(); n];
145 for (idx, entry) in self.entries.iter().enumerate() {
146 for &input_id in &entry.input_ids {
147 if let Some(&producer_idx) = tensor_producers.get(&input_id) {
148 if producer_idx < n {
149 predecessors[idx].push(producer_idx);
150 }
151 }
152 }
153 }
154
155 let mut eft = vec![Duration::ZERO; n];
158 let mut predecessor_on_critical_path = vec![None; n];
159
160 let mut changed = true;
162 for _ in 0..n {
163 if !changed {
165 break;
166 }
167 changed = false;
168
169 for idx in 0..n {
170 let mut max_pred_eft = Duration::ZERO;
171 let mut critical_pred = None;
172
173 for &pred_idx in &predecessors[idx] {
174 if eft[pred_idx] > max_pred_eft {
175 max_pred_eft = eft[pred_idx];
176 critical_pred = Some(pred_idx);
177 }
178 }
179
180 let new_eft = max_pred_eft + self.entries[idx].duration;
181 if new_eft > eft[idx] {
182 eft[idx] = new_eft;
183 predecessor_on_critical_path[idx] = critical_pred;
184 changed = true;
185 }
186 }
187 }
188
189 let critical_end_idx = eft
191 .iter()
192 .enumerate()
193 .max_by_key(|(_, &time)| time)
194 .map(|(idx, _)| idx)
195 .unwrap_or(0);
196
197 let mut critical_path_indices = Vec::new();
199 let mut current = Some(critical_end_idx);
200
201 while let Some(idx) = current {
202 critical_path_indices.push(idx);
203 current = predecessor_on_critical_path[idx];
204 }
205
206 critical_path_indices.reverse();
208
209 critical_path_indices
211 .iter()
212 .map(|&idx| &self.entries[idx])
213 .collect()
214 }
215
216 pub fn critical_path_duration(&self) -> Duration {
218 self.critical_path().iter().map(|e| e.duration).sum()
219 }
220
221 pub fn parallelism_factor(&self) -> f64 {
224 let critical_time = self.critical_path_duration();
225 if critical_time.as_secs_f64() == 0.0 {
226 return 1.0;
227 }
228 self.total_duration.as_secs_f64() / critical_time.as_secs_f64()
229 }
230
231 pub fn slowest_operations(&self, limit: usize) -> Vec<&TraceEntry> {
233 let mut sorted: Vec<_> = self.entries.iter().collect();
234 sorted.sort_by(|a, b| b.duration.cmp(&a.duration));
235 sorted.into_iter().take(limit).collect()
236 }
237
238 pub fn summary(&self) -> TraceSummary {
240 TraceSummary::from_trace(self)
241 }
242}
243
244impl Default for ExecutionTrace {
245 fn default() -> Self {
246 Self::new()
247 }
248}
249
250#[derive(Debug, Clone)]
252pub struct TraceSummary {
253 pub total_operations: usize,
255 pub total_time_ms: f64,
257 pub avg_time_ms: f64,
259 pub max_time_ms: f64,
261 pub min_time_ms: f64,
263 pub operation_counts: HashMap<String, usize>,
265}
266
267impl TraceSummary {
268 pub fn from_trace(trace: &ExecutionTrace) -> Self {
270 let entries = trace.entries();
271 let total_operations = entries.len();
272
273 let total_time_ms = trace.total_duration_ms();
274 let avg_time_ms = if total_operations > 0 {
275 total_time_ms / total_operations as f64
276 } else {
277 0.0
278 };
279
280 let max_time_ms = entries.iter().map(|e| e.duration_ms()).fold(0.0, f64::max);
281 let min_time_ms = entries
282 .iter()
283 .map(|e| e.duration_ms())
284 .fold(f64::MAX, f64::min);
285
286 let mut operation_counts: HashMap<String, usize> = HashMap::new();
287 for entry in entries {
288 *operation_counts.entry(entry.operation.clone()).or_insert(0) += 1;
289 }
290
291 Self {
292 total_operations,
293 total_time_ms,
294 avg_time_ms,
295 max_time_ms,
296 min_time_ms,
297 operation_counts,
298 }
299 }
300}
301
302impl fmt::Display for TraceSummary {
303 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
304 writeln!(f, "Execution Trace Summary")?;
305 writeln!(f, "=======================")?;
306 writeln!(f, "Total operations: {}", self.total_operations)?;
307 writeln!(f, "Total time: {:.2} ms", self.total_time_ms)?;
308 writeln!(f, "Average time: {:.2} ms", self.avg_time_ms)?;
309 writeln!(f, "Max time: {:.2} ms", self.max_time_ms)?;
310 writeln!(f, "Min time: {:.2} ms", self.min_time_ms)?;
311 writeln!(f, "\nOperation Counts:")?;
312 let mut sorted_ops: Vec<_> = self.operation_counts.iter().collect();
313 sorted_ops.sort_by_key(|(_, count)| std::cmp::Reverse(**count));
314 for (op, count) in sorted_ops {
315 writeln!(f, " {}: {}", op, count)?;
316 }
317 Ok(())
318 }
319}
320
321pub struct ExecutionTracer {
323 enabled: bool,
324 current_trace: ExecutionTrace,
325 traces: Vec<ExecutionTrace>,
326 next_entry_id: usize,
327}
328
329impl ExecutionTracer {
330 pub fn new() -> Self {
332 Self {
333 enabled: false,
334 current_trace: ExecutionTrace::new(),
335 traces: Vec::new(),
336 next_entry_id: 0,
337 }
338 }
339
340 pub fn enable(&mut self) {
342 self.enabled = true;
343 }
344
345 pub fn disable(&mut self) {
347 self.enabled = false;
348 }
349
350 pub fn is_enabled(&self) -> bool {
352 self.enabled
353 }
354
355 pub fn start_trace(&mut self, graph_id: Option<usize>) {
357 if !self.current_trace.entries.is_empty() {
358 self.finalize_trace();
359 }
360 self.current_trace = ExecutionTrace::new();
361 if let Some(id) = graph_id {
362 self.current_trace.graph_id = Some(id);
363 }
364 }
365
366 pub fn finalize_trace(&mut self) {
368 if !self.current_trace.entries.is_empty() {
369 let trace = std::mem::take(&mut self.current_trace);
370 self.traces.push(trace);
371 }
372 }
373
374 pub fn record_operation_start(
376 &mut self,
377 _node_id: usize,
378 _operation: impl Into<String>,
379 _input_ids: Vec<usize>,
380 ) -> OperationHandle {
381 if !self.enabled {
382 return OperationHandle {
383 entry_id: None,
384 start_time: Instant::now(),
385 };
386 }
387
388 let entry_id = self.next_entry_id;
389 self.next_entry_id += 1;
390
391 OperationHandle {
392 entry_id: Some(entry_id),
393 start_time: Instant::now(),
394 }
395 }
396
397 pub fn record_operation_end(
399 &mut self,
400 handle: OperationHandle,
401 node_id: usize,
402 operation: impl Into<String>,
403 input_ids: Vec<usize>,
404 output_ids: Vec<usize>,
405 metadata: HashMap<String, String>,
406 ) {
407 if !self.enabled || handle.entry_id.is_none() {
408 return;
409 }
410
411 let duration = handle.start_time.elapsed();
412 let entry = TraceEntry {
413 entry_id: handle.entry_id.unwrap(),
414 node_id,
415 operation: operation.into(),
416 start_time: handle.start_time,
417 duration,
418 input_ids,
419 output_ids,
420 metadata,
421 };
422
423 self.current_trace.add_entry(entry);
424 }
425
426 pub fn get_trace(&self) -> &ExecutionTrace {
428 &self.current_trace
429 }
430
431 pub fn get_all_traces(&self) -> &[ExecutionTrace] {
433 &self.traces
434 }
435
436 pub fn clear(&mut self) {
438 self.current_trace = ExecutionTrace::new();
439 self.traces.clear();
440 self.next_entry_id = 0;
441 }
442}
443
444impl Default for ExecutionTracer {
445 fn default() -> Self {
446 Self::new()
447 }
448}
449
450pub struct OperationHandle {
452 entry_id: Option<usize>,
453 start_time: Instant,
454}
455
/// Shape information and optional numerical statistics for one tensor.
#[derive(Debug, Clone)]
pub struct TensorStats {
    /// Identifier of the tensor.
    pub tensor_id: usize,
    /// Extent of each dimension.
    pub shape: Vec<usize>,
    /// Product of all dimensions in `shape` (1 for an empty shape).
    pub num_elements: usize,
    /// Name of the element type.
    pub dtype: String,
    /// Smallest value, when statistics were supplied.
    pub min_value: Option<f64>,
    /// Largest value, when statistics were supplied.
    pub max_value: Option<f64>,
    /// Mean value, when statistics were supplied.
    pub mean_value: Option<f64>,
    /// Standard deviation, when statistics were supplied.
    pub std_dev: Option<f64>,
    /// Number of NaN elements, when statistics were supplied.
    pub num_nans: Option<usize>,
    /// Number of infinite elements, when statistics were supplied.
    pub num_infs: Option<usize>,
}

impl TensorStats {
    /// Creates stats carrying only shape and dtype; every statistic is `None`.
    pub fn new(tensor_id: usize, shape: Vec<usize>, dtype: impl Into<String>) -> Self {
        let num_elements: usize = shape.iter().product();
        let dtype = dtype.into();
        Self {
            tensor_id,
            shape,
            num_elements,
            dtype,
            min_value: None,
            max_value: None,
            mean_value: None,
            std_dev: None,
            num_nans: None,
            num_infs: None,
        }
    }

    /// Builder-style setter that fills in all six statistics at once.
    pub fn with_statistics(
        self,
        min: f64,
        max: f64,
        mean: f64,
        std_dev: f64,
        num_nans: usize,
        num_infs: usize,
    ) -> Self {
        Self {
            min_value: Some(min),
            max_value: Some(max),
            mean_value: Some(mean),
            std_dev: Some(std_dev),
            num_nans: Some(num_nans),
            num_infs: Some(num_infs),
            ..self
        }
    }

    /// Returns `true` when a non-zero NaN or infinity count was recorded.
    /// Missing counts are treated as zero.
    pub fn has_numerical_issues(&self) -> bool {
        let positive = |count: Option<usize>| count.map_or(false, |c| c > 0);
        positive(self.num_nans) || positive(self.num_infs)
    }
}

impl fmt::Display for TensorStats {
    /// Writes one line per known property. Statistics that are `None` are
    /// skipped, and NaN/Inf counts are printed only when greater than zero.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(f, "Tensor {} Stats:", self.tensor_id)?;
        writeln!(f, " Shape: {:?}", self.shape)?;
        writeln!(f, " Elements: {}", self.num_elements)?;
        writeln!(f, " DType: {}", self.dtype)?;

        if let Some(lowest) = self.min_value {
            writeln!(f, " Min: {:.6}", lowest)?;
        }
        if let Some(highest) = self.max_value {
            writeln!(f, " Max: {:.6}", highest)?;
        }
        if let Some(average) = self.mean_value {
            writeln!(f, " Mean: {:.6}", average)?;
        }
        if let Some(spread) = self.std_dev {
            writeln!(f, " Std Dev: {:.6}", spread)?;
        }

        match self.num_nans {
            Some(nan_count) if nan_count > 0 => writeln!(f, " ⚠️ NaNs: {}", nan_count)?,
            _ => {}
        }
        match self.num_infs {
            Some(inf_count) if inf_count > 0 => writeln!(f, " ⚠️ Infs: {}", inf_count)?,
            _ => {}
        }
        Ok(())
    }
}
555
556pub struct TensorInspector {
558 enabled: bool,
559 tensor_stats: HashMap<usize, TensorStats>,
560 watch_list: Vec<usize>,
561}
562
563impl TensorInspector {
564 pub fn new() -> Self {
566 Self {
567 enabled: false,
568 tensor_stats: HashMap::new(),
569 watch_list: Vec::new(),
570 }
571 }
572
573 pub fn enable(&mut self) {
575 self.enabled = true;
576 }
577
578 pub fn disable(&mut self) {
580 self.enabled = false;
581 }
582
583 pub fn is_enabled(&self) -> bool {
585 self.enabled
586 }
587
588 pub fn watch(&mut self, tensor_id: usize) {
590 if !self.watch_list.contains(&tensor_id) {
591 self.watch_list.push(tensor_id);
592 }
593 }
594
595 pub fn unwatch(&mut self, tensor_id: usize) {
597 self.watch_list.retain(|&id| id != tensor_id);
598 }
599
600 pub fn clear_watch_list(&mut self) {
602 self.watch_list.clear();
603 }
604
605 pub fn should_inspect(&self, tensor_id: usize) -> bool {
607 self.enabled && (self.watch_list.is_empty() || self.watch_list.contains(&tensor_id))
608 }
609
610 pub fn record_stats(&mut self, stats: TensorStats) {
612 if !self.enabled {
613 return;
614 }
615 self.tensor_stats.insert(stats.tensor_id, stats);
616 }
617
618 pub fn get_stats(&self, tensor_id: usize) -> Option<&TensorStats> {
620 self.tensor_stats.get(&tensor_id)
621 }
622
623 pub fn get_all_stats(&self) -> &HashMap<usize, TensorStats> {
625 &self.tensor_stats
626 }
627
628 pub fn find_problematic_tensors(&self) -> Vec<&TensorStats> {
630 self.tensor_stats
631 .values()
632 .filter(|stats| stats.has_numerical_issues())
633 .collect()
634 }
635
636 pub fn clear(&mut self) {
638 self.tensor_stats.clear();
639 }
640}
641
642impl Default for TensorInspector {
643 fn default() -> Self {
644 Self::new()
645 }
646}
647
/// Condition under which execution should pause.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Breakpoint {
    /// Triggers on the node with this id.
    Node(usize),
    /// Triggers on operations with exactly this name.
    Operation(String),
    /// Triggers when the caller reports a numerical issue.
    NumericalIssue,
    /// Triggers when the elapsed time in microseconds exceeds this value.
    TimeThreshold(u64),
    /// Carries a condition string; `BreakpointManager::should_break` never
    /// triggers on this variant.
    Conditional(String),
}

/// Record of a breakpoint that triggered.
#[derive(Debug, Clone)]
pub struct BreakpointHit {
    /// The breakpoint that matched.
    pub breakpoint: Breakpoint,
    /// Node being executed when it matched.
    pub node_id: usize,
    /// Elapsed time, in microseconds, reported by the caller.
    pub elapsed_us: u64,
    /// Additional key/value context (created empty by the manager).
    pub context: HashMap<String, String>,
}

/// Holds a list of breakpoints, evaluates them against executing operations
/// and remembers every hit.
pub struct BreakpointManager {
    /// Whether breakpoints are evaluated at all.
    enabled: bool,
    /// Registered breakpoints, in registration order.
    breakpoints: Vec<Breakpoint>,
    /// Every hit produced so far.
    hits: Vec<BreakpointHit>,
    /// `false` after a hit, until `continue_execution` is called.
    continue_execution: bool,
}

impl BreakpointManager {
    /// Creates a disabled manager with no breakpoints and no hits.
    pub fn new() -> Self {
        Self {
            enabled: false,
            breakpoints: Vec::new(),
            hits: Vec::new(),
            continue_execution: true,
        }
    }

    /// Turns breakpoint evaluation on.
    pub fn enable(&mut self) {
        self.enabled = true;
    }

    /// Turns breakpoint evaluation off.
    pub fn disable(&mut self) {
        self.enabled = false;
    }

    /// Reports whether breakpoint evaluation is on.
    pub fn is_enabled(&self) -> bool {
        self.enabled
    }

    /// Registers a breakpoint on the node with id `node_id`.
    pub fn add_node_breakpoint(&mut self, node_id: usize) {
        self.breakpoints.push(Breakpoint::Node(node_id));
    }

    /// Registers a breakpoint on operations named `operation`.
    pub fn add_operation_breakpoint(&mut self, operation: impl Into<String>) {
        let name = operation.into();
        self.breakpoints.push(Breakpoint::Operation(name));
    }

    /// Registers a breakpoint that triggers on reported numerical issues.
    pub fn add_numerical_issue_breakpoint(&mut self) {
        self.breakpoints.push(Breakpoint::NumericalIssue);
    }

    /// Registers a breakpoint for operations taking longer than
    /// `threshold_us` microseconds.
    pub fn add_time_threshold_breakpoint(&mut self, threshold_us: u64) {
        self.breakpoints.push(Breakpoint::TimeThreshold(threshold_us));
    }

    /// Removes every registered breakpoint equal to `breakpoint`.
    pub fn remove_breakpoint(&mut self, breakpoint: &Breakpoint) {
        self.breakpoints.retain(|registered| registered != breakpoint);
    }

    /// Removes all registered breakpoints.
    pub fn clear_breakpoints(&mut self) {
        self.breakpoints.clear();
    }

    /// Registered breakpoints, in registration order.
    pub fn get_breakpoints(&self) -> &[Breakpoint] {
        &self.breakpoints
    }

    /// Decides whether `candidate` matches the described operation.
    fn triggers(
        candidate: &Breakpoint,
        node_id: usize,
        operation: &str,
        elapsed_us: u64,
        has_numerical_issue: bool,
    ) -> bool {
        match candidate {
            Breakpoint::Node(target) => *target == node_id,
            Breakpoint::Operation(name) => name.as_str() == operation,
            Breakpoint::NumericalIssue => has_numerical_issue,
            Breakpoint::TimeThreshold(limit) => elapsed_us > *limit,
            Breakpoint::Conditional(_) => false,
        }
    }

    /// Evaluates the registered breakpoints against one operation.
    ///
    /// Returns the hit for the first matching breakpoint (in registration
    /// order), records it, and pauses the manager: further calls return
    /// `None` until [`continue_execution`](Self::continue_execution) is
    /// called. Also returns `None` while the manager is disabled or when no
    /// breakpoint matches.
    pub fn should_break(
        &mut self,
        node_id: usize,
        operation: &str,
        elapsed_us: u64,
        has_numerical_issue: bool,
    ) -> Option<BreakpointHit> {
        let armed = self.enabled && self.continue_execution;
        if !armed {
            return None;
        }

        let matched = self
            .breakpoints
            .iter()
            .find(|bp| Self::triggers(bp, node_id, operation, elapsed_us, has_numerical_issue))?
            .clone();

        let hit = BreakpointHit {
            breakpoint: matched,
            node_id,
            elapsed_us,
            context: HashMap::new(),
        };
        self.hits.push(hit.clone());
        self.continue_execution = false;
        Some(hit)
    }

    /// Resumes evaluation after a hit.
    pub fn continue_execution(&mut self) {
        self.continue_execution = true;
    }

    /// Every hit recorded so far, oldest first.
    pub fn get_hits(&self) -> &[BreakpointHit] {
        &self.hits
    }

    /// Forgets all recorded hits.
    pub fn clear_hits(&mut self) {
        self.hits.clear();
    }
}

impl Default for BreakpointManager {
    fn default() -> Self {
        Self::new()
    }
}
805
806pub struct ExecutionRecorder {
808 enabled: bool,
809 tracer: ExecutionTracer,
810 inspector: TensorInspector,
811 breakpoints: BreakpointManager,
812}
813
814impl ExecutionRecorder {
815 pub fn new() -> Self {
817 Self {
818 enabled: false,
819 tracer: ExecutionTracer::new(),
820 inspector: TensorInspector::new(),
821 breakpoints: BreakpointManager::new(),
822 }
823 }
824
825 pub fn enable(&mut self) {
827 self.enabled = true;
828 self.tracer.enable();
829 self.inspector.enable();
830 self.breakpoints.enable();
831 }
832
833 pub fn disable(&mut self) {
835 self.enabled = false;
836 self.tracer.disable();
837 self.inspector.disable();
838 self.breakpoints.disable();
839 }
840
841 pub fn tracer(&mut self) -> &mut ExecutionTracer {
843 &mut self.tracer
844 }
845
846 pub fn inspector(&mut self) -> &mut TensorInspector {
848 &mut self.inspector
849 }
850
851 pub fn breakpoints(&mut self) -> &mut BreakpointManager {
853 &mut self.breakpoints
854 }
855
856 pub fn clear(&mut self) {
858 self.tracer.clear();
859 self.inspector.clear();
860 self.breakpoints.clear_hits();
861 }
862
863 pub fn generate_report(&self) -> ExecutionReport {
865 ExecutionReport {
866 trace_summary: self.tracer.get_trace().summary(),
867 problematic_tensors: self.inspector.find_problematic_tensors().len(),
868 breakpoint_hits: self.breakpoints.get_hits().len(),
869 }
870 }
871}
872
873impl Default for ExecutionRecorder {
874 fn default() -> Self {
875 Self::new()
876 }
877}
878
879#[derive(Debug, Clone)]
881pub struct ExecutionReport {
882 pub trace_summary: TraceSummary,
884 pub problematic_tensors: usize,
886 pub breakpoint_hits: usize,
888}
889
890impl fmt::Display for ExecutionReport {
891 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
892 writeln!(f, "{}", self.trace_summary)?;
893 writeln!(f, "\nDebug Information:")?;
894 writeln!(f, " Problematic tensors: {}", self.problematic_tensors)?;
895 writeln!(f, " Breakpoint hits: {}", self.breakpoint_hits)?;
896 Ok(())
897 }
898}
899
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a trace entry whose `entry_id` and `node_id` are both `id`,
    /// lasting `millis` milliseconds, with empty metadata.
    fn entry(
        id: usize,
        operation: &str,
        millis: u64,
        input_ids: Vec<usize>,
        output_ids: Vec<usize>,
    ) -> TraceEntry {
        TraceEntry {
            entry_id: id,
            node_id: id,
            operation: operation.to_string(),
            start_time: Instant::now(),
            duration: Duration::from_millis(millis),
            input_ids,
            output_ids,
            metadata: HashMap::new(),
        }
    }

    /// Collects the entries into a fresh trace, preserving their order.
    fn trace_of(entries: Vec<TraceEntry>) -> ExecutionTrace {
        let mut trace = ExecutionTrace::new();
        for e in entries {
            trace.add_entry(e);
        }
        trace
    }

    /// Node ids along a path, in order.
    fn node_ids(path: &[&TraceEntry]) -> Vec<usize> {
        path.iter().map(|e| e.node_id).collect()
    }

    #[test]
    fn test_execution_tracer() {
        let mut tracer = ExecutionTracer::new();
        assert!(!tracer.is_enabled());

        tracer.enable();
        assert!(tracer.is_enabled());

        tracer.start_trace(Some(1));
        let handle = tracer.record_operation_start(0, "einsum", vec![0, 1]);
        std::thread::sleep(Duration::from_millis(10));
        tracer.record_operation_end(handle, 0, "einsum", vec![0, 1], vec![2], HashMap::new());

        let trace = tracer.get_trace();
        assert_eq!(trace.entries().len(), 1);
        // The operation slept for 10 ms, so at least that much was recorded.
        assert!(trace.total_duration_ms() >= 10.0);
    }

    #[test]
    fn test_trace_summary() {
        let trace = trace_of(vec![entry(0, "einsum", 10, vec![0], vec![1])]);

        let summary = trace.summary();
        assert_eq!(summary.total_operations, 1);
        assert!(summary.total_time_ms >= 10.0);
    }

    #[test]
    fn test_tensor_inspector() {
        let mut inspector = TensorInspector::new();
        inspector.enable();

        let stats =
            TensorStats::new(0, vec![2, 3], "f64").with_statistics(0.0, 1.0, 0.5, 0.25, 0, 0);

        inspector.record_stats(stats.clone());
        assert_eq!(inspector.get_stats(0).unwrap().tensor_id, 0);
        assert!(!stats.has_numerical_issues());
    }

    #[test]
    fn test_tensor_numerical_issues() {
        let stats = TensorStats::new(0, vec![2, 3], "f64")
            .with_statistics(0.0, f64::INFINITY, 0.5, 0.25, 1, 1);

        assert!(stats.has_numerical_issues());
    }

    #[test]
    fn test_breakpoint_manager() {
        let mut manager = BreakpointManager::new();
        manager.enable();
        manager.add_node_breakpoint(5);

        let hit = manager.should_break(5, "einsum", 1000, false);
        assert!(hit.is_some());
        assert_eq!(hit.unwrap().node_id, 5);

        // After resuming, the same breakpoint triggers again.
        manager.continue_execution();
        let hit2 = manager.should_break(5, "einsum", 1000, false);
        assert!(hit2.is_some());
    }

    #[test]
    fn test_operation_breakpoint() {
        let mut manager = BreakpointManager::new();
        manager.enable();
        manager.add_operation_breakpoint("matmul");

        let hit = manager.should_break(1, "matmul", 1000, false);
        assert!(hit.is_some());

        let no_hit = manager.should_break(2, "add", 1000, false);
        assert!(no_hit.is_none());
    }

    #[test]
    fn test_time_threshold_breakpoint() {
        let mut manager = BreakpointManager::new();
        manager.enable();
        manager.add_time_threshold_breakpoint(5000);

        let no_hit = manager.should_break(1, "op", 4000, false);
        assert!(no_hit.is_none());

        let hit = manager.should_break(1, "op", 6000, false);
        assert!(hit.is_some());
    }

    #[test]
    fn test_numerical_issue_breakpoint() {
        let mut manager = BreakpointManager::new();
        manager.enable();
        manager.add_numerical_issue_breakpoint();

        let no_hit = manager.should_break(1, "op", 1000, false);
        assert!(no_hit.is_none());

        let hit = manager.should_break(1, "op", 1000, true);
        assert!(hit.is_some());
    }

    #[test]
    fn test_execution_recorder() {
        let mut recorder = ExecutionRecorder::new();
        recorder.enable();

        assert!(recorder.tracer().is_enabled());
        assert!(recorder.inspector().is_enabled());
        assert!(recorder.breakpoints().is_enabled());

        recorder.clear();
        let report = recorder.generate_report();
        assert_eq!(report.trace_summary.total_operations, 0);
    }

    #[test]
    fn test_slowest_operations() {
        // Entry `i` lasts (i + 1) * 10 ms, so higher ids are slower.
        let trace = trace_of(
            (0..5)
                .map(|i| entry(i, &format!("op{}", i), (i as u64 + 1) * 10, vec![], vec![]))
                .collect(),
        );

        let slowest = trace.slowest_operations(3);
        assert_eq!(slowest.len(), 3);
        assert_eq!(slowest[0].node_id, 4);
        assert_eq!(slowest[1].node_id, 3);
        assert_eq!(slowest[2].node_id, 2);
    }

    #[test]
    fn test_watch_list() {
        let mut inspector = TensorInspector::new();
        inspector.enable();

        inspector.watch(1);
        inspector.watch(2);

        assert!(inspector.should_inspect(1));
        assert!(inspector.should_inspect(2));
        assert!(!inspector.should_inspect(3));

        inspector.unwatch(1);
        assert!(!inspector.should_inspect(1));
        assert!(inspector.should_inspect(2));

        // With an empty watch list every tensor is inspected.
        inspector.clear_watch_list();
        assert!(inspector.should_inspect(5));
    }

    #[test]
    fn test_trace_entries_for_node() {
        // Entries 0 and 2 belong to node 0, entry 1 to node 1.
        let trace = trace_of(
            (0..3)
                .map(|i| TraceEntry {
                    node_id: i % 2,
                    ..entry(i, "op", 10, vec![], vec![])
                })
                .collect(),
        );

        let node_0_entries = trace.entries_for_node(0);
        assert_eq!(node_0_entries.len(), 2);

        let node_1_entries = trace.entries_for_node(1);
        assert_eq!(node_1_entries.len(), 1);
    }

    #[test]
    fn test_critical_path_linear_chain() {
        // op0 -> op1 -> op2
        let trace = trace_of(vec![
            entry(0, "op0", 10, vec![], vec![0]),
            entry(1, "op1", 20, vec![0], vec![1]),
            entry(2, "op2", 15, vec![1], vec![2]),
        ]);

        let critical_path = trace.critical_path();
        assert_eq!(node_ids(&critical_path), vec![0, 1, 2]);

        // 10 + 20 + 15
        let cp_duration = trace.critical_path_duration();
        assert_eq!(cp_duration, Duration::from_millis(45));
    }

    #[test]
    fn test_critical_path_parallel_operations() {
        // op0 feeds op1 and op2 independently; op2 is the slower branch.
        let trace = trace_of(vec![
            entry(0, "op0", 10, vec![], vec![0, 1]),
            entry(1, "op1", 5, vec![0], vec![2]),
            entry(2, "op2", 20, vec![1], vec![3]),
        ]);

        let critical_path = trace.critical_path();
        assert_eq!(node_ids(&critical_path), vec![0, 2]);

        // 10 + 20
        let cp_duration = trace.critical_path_duration();
        assert_eq!(cp_duration, Duration::from_millis(30));
    }

    #[test]
    fn test_critical_path_diamond_pattern() {
        // op0 -> {op1, op2} -> op3; op2 is the slower branch.
        let trace = trace_of(vec![
            entry(0, "op0", 10, vec![], vec![0]),
            entry(1, "op1", 5, vec![0], vec![1]),
            entry(2, "op2", 25, vec![0], vec![2]),
            entry(3, "op3", 15, vec![1, 2], vec![3]),
        ]);

        let critical_path = trace.critical_path();
        assert_eq!(node_ids(&critical_path), vec![0, 2, 3]);

        // 10 + 25 + 15
        let cp_duration = trace.critical_path_duration();
        assert_eq!(cp_duration, Duration::from_millis(50));
    }

    #[test]
    fn test_critical_path_empty() {
        let trace = ExecutionTrace::new();
        let critical_path = trace.critical_path();
        assert_eq!(critical_path.len(), 0);
        assert_eq!(trace.critical_path_duration(), Duration::ZERO);
    }

    #[test]
    fn test_critical_path_single_operation() {
        let trace = trace_of(vec![entry(0, "op0", 10, vec![], vec![0])]);

        let critical_path = trace.critical_path();
        assert_eq!(critical_path.len(), 1);
        assert_eq!(critical_path[0].node_id, 0);
    }

    #[test]
    fn test_parallelism_factor() {
        // Diamond: total = 10 + 20 + 30 + 40 = 100 ms,
        // critical path = 10 + 30 + 40 = 80 ms, factor = 100 / 80.
        let trace = trace_of(vec![
            entry(0, "op0", 10, vec![], vec![0]),
            entry(1, "op1", 20, vec![0], vec![1]),
            entry(2, "op2", 30, vec![0], vec![2]),
            entry(3, "op3", 40, vec![1, 2], vec![3]),
        ]);

        let parallelism = trace.parallelism_factor();
        assert!((parallelism - 1.25).abs() < 0.01);
    }

    #[test]
    fn test_critical_path_complex_graph() {
        // root fans out to three branches; two of them feed the merge.
        let mut entries = vec![entry(0, "root", 5, vec![], vec![0])];
        for i in 1..=3 {
            entries.push(entry(
                i,
                &format!("branch{}", i),
                (i as u64) * 10,
                vec![0],
                vec![i],
            ));
        }
        entries.push(entry(4, "merge", 15, vec![2, 3], vec![4]));
        let trace = trace_of(entries);

        let critical_path = trace.critical_path();
        assert!(critical_path.len() >= 3);
        assert_eq!(critical_path[0].operation, "root");
        assert_eq!(critical_path[critical_path.len() - 1].operation, "merge");
    }
}