1use std::cmp::Reverse;
34use std::collections::HashMap;
35use std::fmt;
36use std::time::{Duration, Instant};
37
/// One recorded operation: which node ran, what it consumed and produced,
/// when it started and how long it took.
#[derive(Debug, Clone)]
pub struct TraceEntry {
    /// Identifier of this entry.
    pub entry_id: usize,
    /// Identifier of the node the operation belongs to.
    pub node_id: usize,
    /// Name of the operation.
    pub operation: String,
    /// Instant at which the operation started.
    pub start_time: Instant,
    /// Time the operation took.
    pub duration: Duration,
    /// Identifiers of the tensors the operation consumed.
    pub input_ids: Vec<usize>,
    /// Identifiers of the tensors the operation produced.
    pub output_ids: Vec<usize>,
    /// Free-form key/value annotations attached to the entry.
    pub metadata: HashMap<String, String>,
}

impl TraceEntry {
    /// Converts the duration to fractional seconds and scales it by
    /// `units_per_second`.
    fn duration_scaled(&self, units_per_second: f64) -> f64 {
        self.duration.as_secs_f64() * units_per_second
    }

    /// Returns the duration in milliseconds, including the fractional part.
    pub fn duration_ms(&self) -> f64 {
        self.duration_scaled(1000.0)
    }

    /// Returns the duration in microseconds, including the fractional part.
    pub fn duration_us(&self) -> f64 {
        self.duration_scaled(1_000_000.0)
    }
}
70
71#[derive(Debug, Clone)]
73pub struct ExecutionTrace {
74 entries: Vec<TraceEntry>,
75 total_duration: Duration,
76 graph_id: Option<usize>,
77}
78
79impl ExecutionTrace {
80 pub fn new() -> Self {
82 Self {
83 entries: Vec::new(),
84 total_duration: Duration::ZERO,
85 graph_id: None,
86 }
87 }
88
89 pub fn with_graph_id(mut self, graph_id: usize) -> Self {
91 self.graph_id = Some(graph_id);
92 self
93 }
94
95 pub fn add_entry(&mut self, entry: TraceEntry) {
97 self.total_duration += entry.duration;
98 self.entries.push(entry);
99 }
100
101 pub fn entries(&self) -> &[TraceEntry] {
103 &self.entries
104 }
105
106 pub fn total_duration(&self) -> Duration {
108 self.total_duration
109 }
110
111 pub fn total_duration_ms(&self) -> f64 {
113 self.total_duration.as_secs_f64() * 1000.0
114 }
115
116 pub fn entries_for_node(&self, node_id: usize) -> Vec<&TraceEntry> {
118 self.entries
119 .iter()
120 .filter(|e| e.node_id == node_id)
121 .collect()
122 }
123
124 pub fn critical_path(&self) -> Vec<&TraceEntry> {
130 if self.entries.is_empty() {
131 return Vec::new();
132 }
133
134 let n = self.entries.len();
135
136 let mut tensor_producers: HashMap<usize, usize> = HashMap::new();
138 for (idx, entry) in self.entries.iter().enumerate() {
139 for &output_id in &entry.output_ids {
140 tensor_producers.insert(output_id, idx);
141 }
142 }
143
144 let mut predecessors: Vec<Vec<usize>> = vec![Vec::new(); n];
146 for (idx, entry) in self.entries.iter().enumerate() {
147 for &input_id in &entry.input_ids {
148 if let Some(&producer_idx) = tensor_producers.get(&input_id) {
149 if producer_idx < n {
150 predecessors[idx].push(producer_idx);
151 }
152 }
153 }
154 }
155
156 let mut eft = vec![Duration::ZERO; n];
159 let mut predecessor_on_critical_path = vec![None; n];
160
161 let mut changed = true;
163 for _ in 0..n {
164 if !changed {
166 break;
167 }
168 changed = false;
169
170 for idx in 0..n {
171 let mut max_pred_eft = Duration::ZERO;
172 let mut critical_pred = None;
173
174 for &pred_idx in &predecessors[idx] {
175 if eft[pred_idx] > max_pred_eft {
176 max_pred_eft = eft[pred_idx];
177 critical_pred = Some(pred_idx);
178 }
179 }
180
181 let new_eft = max_pred_eft + self.entries[idx].duration;
182 if new_eft > eft[idx] {
183 eft[idx] = new_eft;
184 predecessor_on_critical_path[idx] = critical_pred;
185 changed = true;
186 }
187 }
188 }
189
190 let critical_end_idx = eft
192 .iter()
193 .enumerate()
194 .max_by_key(|(_, &time)| time)
195 .map(|(idx, _)| idx)
196 .unwrap_or(0);
197
198 let mut critical_path_indices = Vec::new();
200 let mut current = Some(critical_end_idx);
201
202 while let Some(idx) = current {
203 critical_path_indices.push(idx);
204 current = predecessor_on_critical_path[idx];
205 }
206
207 critical_path_indices.reverse();
209
210 critical_path_indices
212 .iter()
213 .map(|&idx| &self.entries[idx])
214 .collect()
215 }
216
217 pub fn critical_path_duration(&self) -> Duration {
219 self.critical_path().iter().map(|e| e.duration).sum()
220 }
221
222 pub fn parallelism_factor(&self) -> f64 {
225 let critical_time = self.critical_path_duration();
226 if critical_time.as_secs_f64() == 0.0 {
227 return 1.0;
228 }
229 self.total_duration.as_secs_f64() / critical_time.as_secs_f64()
230 }
231
232 pub fn slowest_operations(&self, limit: usize) -> Vec<&TraceEntry> {
234 let mut sorted: Vec<_> = self.entries.iter().collect();
235 sorted.sort_by_key(|b| Reverse(b.duration));
236 sorted.into_iter().take(limit).collect()
237 }
238
239 pub fn summary(&self) -> TraceSummary {
241 TraceSummary::from_trace(self)
242 }
243}
244
245impl Default for ExecutionTrace {
246 fn default() -> Self {
247 Self::new()
248 }
249}
250
251#[derive(Debug, Clone)]
253pub struct TraceSummary {
254 pub total_operations: usize,
256 pub total_time_ms: f64,
258 pub avg_time_ms: f64,
260 pub max_time_ms: f64,
262 pub min_time_ms: f64,
264 pub operation_counts: HashMap<String, usize>,
266}
267
268impl TraceSummary {
269 pub fn from_trace(trace: &ExecutionTrace) -> Self {
271 let entries = trace.entries();
272 let total_operations = entries.len();
273
274 let total_time_ms = trace.total_duration_ms();
275 let avg_time_ms = if total_operations > 0 {
276 total_time_ms / total_operations as f64
277 } else {
278 0.0
279 };
280
281 let max_time_ms = entries.iter().map(|e| e.duration_ms()).fold(0.0, f64::max);
282 let min_time_ms = entries
283 .iter()
284 .map(|e| e.duration_ms())
285 .fold(f64::MAX, f64::min);
286
287 let mut operation_counts: HashMap<String, usize> = HashMap::new();
288 for entry in entries {
289 *operation_counts.entry(entry.operation.clone()).or_insert(0) += 1;
290 }
291
292 Self {
293 total_operations,
294 total_time_ms,
295 avg_time_ms,
296 max_time_ms,
297 min_time_ms,
298 operation_counts,
299 }
300 }
301}
302
303impl fmt::Display for TraceSummary {
304 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
305 writeln!(f, "Execution Trace Summary")?;
306 writeln!(f, "=======================")?;
307 writeln!(f, "Total operations: {}", self.total_operations)?;
308 writeln!(f, "Total time: {:.2} ms", self.total_time_ms)?;
309 writeln!(f, "Average time: {:.2} ms", self.avg_time_ms)?;
310 writeln!(f, "Max time: {:.2} ms", self.max_time_ms)?;
311 writeln!(f, "Min time: {:.2} ms", self.min_time_ms)?;
312 writeln!(f, "\nOperation Counts:")?;
313 let mut sorted_ops: Vec<_> = self.operation_counts.iter().collect();
314 sorted_ops.sort_by_key(|(_, count)| std::cmp::Reverse(**count));
315 for (op, count) in sorted_ops {
316 writeln!(f, " {}: {}", op, count)?;
317 }
318 Ok(())
319 }
320}
321
322pub struct ExecutionTracer {
324 enabled: bool,
325 current_trace: ExecutionTrace,
326 traces: Vec<ExecutionTrace>,
327 next_entry_id: usize,
328}
329
330impl ExecutionTracer {
331 pub fn new() -> Self {
333 Self {
334 enabled: false,
335 current_trace: ExecutionTrace::new(),
336 traces: Vec::new(),
337 next_entry_id: 0,
338 }
339 }
340
341 pub fn enable(&mut self) {
343 self.enabled = true;
344 }
345
346 pub fn disable(&mut self) {
348 self.enabled = false;
349 }
350
351 pub fn is_enabled(&self) -> bool {
353 self.enabled
354 }
355
356 pub fn start_trace(&mut self, graph_id: Option<usize>) {
358 if !self.current_trace.entries.is_empty() {
359 self.finalize_trace();
360 }
361 self.current_trace = ExecutionTrace::new();
362 if let Some(id) = graph_id {
363 self.current_trace.graph_id = Some(id);
364 }
365 }
366
367 pub fn finalize_trace(&mut self) {
369 if !self.current_trace.entries.is_empty() {
370 let trace = std::mem::take(&mut self.current_trace);
371 self.traces.push(trace);
372 }
373 }
374
375 pub fn record_operation_start(
377 &mut self,
378 _node_id: usize,
379 _operation: impl Into<String>,
380 _input_ids: Vec<usize>,
381 ) -> OperationHandle {
382 if !self.enabled {
383 return OperationHandle {
384 entry_id: None,
385 start_time: Instant::now(),
386 };
387 }
388
389 let entry_id = self.next_entry_id;
390 self.next_entry_id += 1;
391
392 OperationHandle {
393 entry_id: Some(entry_id),
394 start_time: Instant::now(),
395 }
396 }
397
398 pub fn record_operation_end(
400 &mut self,
401 handle: OperationHandle,
402 node_id: usize,
403 operation: impl Into<String>,
404 input_ids: Vec<usize>,
405 output_ids: Vec<usize>,
406 metadata: HashMap<String, String>,
407 ) {
408 if !self.enabled || handle.entry_id.is_none() {
409 return;
410 }
411
412 let duration = handle.start_time.elapsed();
413 let entry = TraceEntry {
414 entry_id: handle
415 .entry_id
416 .expect("entry_id is Some after is_none guard above"),
417 node_id,
418 operation: operation.into(),
419 start_time: handle.start_time,
420 duration,
421 input_ids,
422 output_ids,
423 metadata,
424 };
425
426 self.current_trace.add_entry(entry);
427 }
428
429 pub fn get_trace(&self) -> &ExecutionTrace {
431 &self.current_trace
432 }
433
434 pub fn get_all_traces(&self) -> &[ExecutionTrace] {
436 &self.traces
437 }
438
439 pub fn clear(&mut self) {
441 self.current_trace = ExecutionTrace::new();
442 self.traces.clear();
443 self.next_entry_id = 0;
444 }
445}
446
447impl Default for ExecutionTracer {
448 fn default() -> Self {
449 Self::new()
450 }
451}
452
453pub struct OperationHandle {
455 entry_id: Option<usize>,
456 start_time: Instant,
457}
458
/// Shape information and optional value statistics for one tensor.
#[derive(Debug, Clone)]
pub struct TensorStats {
    /// Identifier of the tensor.
    pub tensor_id: usize,
    /// Extent of each dimension.
    pub shape: Vec<usize>,
    /// Product of the extents in `shape`.
    pub num_elements: usize,
    /// Name of the element type.
    pub dtype: String,
    /// Smallest value, when statistics were supplied.
    pub min_value: Option<f64>,
    /// Largest value, when statistics were supplied.
    pub max_value: Option<f64>,
    /// Mean value, when statistics were supplied.
    pub mean_value: Option<f64>,
    /// Standard deviation, when statistics were supplied.
    pub std_dev: Option<f64>,
    /// Number of NaN elements, when statistics were supplied.
    pub num_nans: Option<usize>,
    /// Number of infinite elements, when statistics were supplied.
    pub num_infs: Option<usize>,
}

impl TensorStats {
    /// Creates stats for a tensor with the given shape and dtype; every
    /// value statistic starts out unset.
    pub fn new(tensor_id: usize, shape: Vec<usize>, dtype: impl Into<String>) -> Self {
        Self {
            tensor_id,
            // Evaluated before `shape` is moved into the struct below.
            num_elements: shape.iter().product(),
            shape,
            dtype: dtype.into(),
            min_value: None,
            max_value: None,
            mean_value: None,
            std_dev: None,
            num_nans: None,
            num_infs: None,
        }
    }

    /// Builder-style setter that fills in all six value statistics.
    pub fn with_statistics(
        self,
        min: f64,
        max: f64,
        mean: f64,
        std_dev: f64,
        num_nans: usize,
        num_infs: usize,
    ) -> Self {
        Self {
            min_value: Some(min),
            max_value: Some(max),
            mean_value: Some(mean),
            std_dev: Some(std_dev),
            num_nans: Some(num_nans),
            num_infs: Some(num_infs),
            ..self
        }
    }

    /// Returns `true` when a recorded NaN or infinity count is non-zero.
    /// Unset counts are treated as zero.
    pub fn has_numerical_issues(&self) -> bool {
        [self.num_nans, self.num_infs]
            .iter()
            .any(|count| matches!(count, Some(n) if *n > 0))
    }
}

impl fmt::Display for TensorStats {
    /// Writes one line per field; unset statistics are skipped and the
    /// NaN/Inf lines appear only for non-zero counts.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(f, "Tensor {} Stats:", self.tensor_id)?;
        writeln!(f, " Shape: {:?}", self.shape)?;
        writeln!(f, " Elements: {}", self.num_elements)?;
        writeln!(f, " DType: {}", self.dtype)?;

        if let Some(min) = self.min_value {
            writeln!(f, " Min: {:.6}", min)?;
        }
        if let Some(max) = self.max_value {
            writeln!(f, " Max: {:.6}", max)?;
        }
        if let Some(mean) = self.mean_value {
            writeln!(f, " Mean: {:.6}", mean)?;
        }
        if let Some(std) = self.std_dev {
            writeln!(f, " Std Dev: {:.6}", std)?;
        }

        if let Some(nans) = self.num_nans.filter(|&n| n > 0) {
            writeln!(f, " ⚠️ NaNs: {}", nans)?;
        }
        if let Some(infs) = self.num_infs.filter(|&n| n > 0) {
            writeln!(f, " ⚠️ Infs: {}", infs)?;
        }
        Ok(())
    }
}
558
559pub struct TensorInspector {
561 enabled: bool,
562 tensor_stats: HashMap<usize, TensorStats>,
563 watch_list: Vec<usize>,
564}
565
566impl TensorInspector {
567 pub fn new() -> Self {
569 Self {
570 enabled: false,
571 tensor_stats: HashMap::new(),
572 watch_list: Vec::new(),
573 }
574 }
575
576 pub fn enable(&mut self) {
578 self.enabled = true;
579 }
580
581 pub fn disable(&mut self) {
583 self.enabled = false;
584 }
585
586 pub fn is_enabled(&self) -> bool {
588 self.enabled
589 }
590
591 pub fn watch(&mut self, tensor_id: usize) {
593 if !self.watch_list.contains(&tensor_id) {
594 self.watch_list.push(tensor_id);
595 }
596 }
597
598 pub fn unwatch(&mut self, tensor_id: usize) {
600 self.watch_list.retain(|&id| id != tensor_id);
601 }
602
603 pub fn clear_watch_list(&mut self) {
605 self.watch_list.clear();
606 }
607
608 pub fn should_inspect(&self, tensor_id: usize) -> bool {
610 self.enabled && (self.watch_list.is_empty() || self.watch_list.contains(&tensor_id))
611 }
612
613 pub fn record_stats(&mut self, stats: TensorStats) {
615 if !self.enabled {
616 return;
617 }
618 self.tensor_stats.insert(stats.tensor_id, stats);
619 }
620
621 pub fn get_stats(&self, tensor_id: usize) -> Option<&TensorStats> {
623 self.tensor_stats.get(&tensor_id)
624 }
625
626 pub fn get_all_stats(&self) -> &HashMap<usize, TensorStats> {
628 &self.tensor_stats
629 }
630
631 pub fn find_problematic_tensors(&self) -> Vec<&TensorStats> {
633 self.tensor_stats
634 .values()
635 .filter(|stats| stats.has_numerical_issues())
636 .collect()
637 }
638
639 pub fn clear(&mut self) {
641 self.tensor_stats.clear();
642 }
643}
644
645impl Default for TensorInspector {
646 fn default() -> Self {
647 Self::new()
648 }
649}
650
/// Condition under which execution should pause.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Breakpoint {
    /// Triggers for the node with this id.
    Node(usize),
    /// Triggers for operations with exactly this name.
    Operation(String),
    /// Triggers when a numerical issue is reported.
    NumericalIssue,
    /// Triggers when the elapsed time in microseconds exceeds this value.
    TimeThreshold(u64),
    /// Carries a condition string; never triggers in
    /// [`BreakpointManager::should_break`].
    Conditional(String),
}

/// Record of a breakpoint that triggered.
#[derive(Debug, Clone)]
pub struct BreakpointHit {
    /// The breakpoint that triggered.
    pub breakpoint: Breakpoint,
    /// Node being checked when it triggered.
    pub node_id: usize,
    /// Elapsed time, in microseconds, passed to the check.
    pub elapsed_us: u64,
    /// Additional key/value context (created empty).
    pub context: HashMap<String, String>,
}

/// Holds the registered breakpoints, the history of hits, and whether
/// execution is currently paused on a hit.
pub struct BreakpointManager {
    /// Whether breakpoints are evaluated at all.
    enabled: bool,
    /// Registered breakpoints, evaluated in insertion order.
    breakpoints: Vec<Breakpoint>,
    /// Every hit recorded so far, oldest first.
    hits: Vec<BreakpointHit>,
    /// `false` after a hit until `continue_execution` is called.
    continue_execution: bool,
}

impl BreakpointManager {
    /// Creates a disabled manager with no breakpoints and no hits.
    pub fn new() -> Self {
        Self {
            enabled: false,
            breakpoints: Vec::new(),
            hits: Vec::new(),
            continue_execution: true,
        }
    }

    /// Turns breakpoint evaluation on.
    pub fn enable(&mut self) {
        self.enabled = true;
    }

    /// Turns breakpoint evaluation off.
    pub fn disable(&mut self) {
        self.enabled = false;
    }

    /// Reports whether breakpoint evaluation is on.
    pub fn is_enabled(&self) -> bool {
        self.enabled
    }

    /// Registers a breakpoint on the node with id `node_id`.
    pub fn add_node_breakpoint(&mut self, node_id: usize) {
        let bp = Breakpoint::Node(node_id);
        self.breakpoints.push(bp);
    }

    /// Registers a breakpoint on operations named `operation`.
    pub fn add_operation_breakpoint(&mut self, operation: impl Into<String>) {
        let bp = Breakpoint::Operation(operation.into());
        self.breakpoints.push(bp);
    }

    /// Registers a breakpoint that triggers on reported numerical issues.
    pub fn add_numerical_issue_breakpoint(&mut self) {
        let bp = Breakpoint::NumericalIssue;
        self.breakpoints.push(bp);
    }

    /// Registers a breakpoint that triggers when more than `threshold_us`
    /// microseconds have elapsed.
    pub fn add_time_threshold_breakpoint(&mut self, threshold_us: u64) {
        let bp = Breakpoint::TimeThreshold(threshold_us);
        self.breakpoints.push(bp);
    }

    /// Removes every registered breakpoint equal to `breakpoint`.
    pub fn remove_breakpoint(&mut self, breakpoint: &Breakpoint) {
        self.breakpoints.retain(|registered| registered != breakpoint);
    }

    /// Removes all registered breakpoints.
    pub fn clear_breakpoints(&mut self) {
        self.breakpoints.clear();
    }

    /// Returns the registered breakpoints in insertion order.
    pub fn get_breakpoints(&self) -> &[Breakpoint] {
        &self.breakpoints
    }

    /// Evaluates the registered breakpoints against the given operation and
    /// returns the first one that triggers.
    ///
    /// On a hit, the hit is appended to the history and the manager pauses:
    /// further calls return `None` until `continue_execution` is called.
    /// Also returns `None` while the manager is disabled.
    pub fn should_break(
        &mut self,
        node_id: usize,
        operation: &str,
        elapsed_us: u64,
        has_numerical_issue: bool,
    ) -> Option<BreakpointHit> {
        if !(self.enabled && self.continue_execution) {
            return None;
        }

        let triggered = self
            .breakpoints
            .iter()
            .find(|bp| match bp {
                Breakpoint::Node(target) => *target == node_id,
                Breakpoint::Operation(name) => name == operation,
                Breakpoint::NumericalIssue => has_numerical_issue,
                // Strictly greater: hitting the threshold exactly does not trigger.
                Breakpoint::TimeThreshold(limit) => elapsed_us > *limit,
                Breakpoint::Conditional(_) => false,
            })
            .cloned()?;

        let hit = BreakpointHit {
            breakpoint: triggered,
            node_id,
            elapsed_us,
            context: HashMap::new(),
        };
        self.hits.push(hit.clone());
        self.continue_execution = false;
        Some(hit)
    }

    /// Resumes evaluation after a hit.
    pub fn continue_execution(&mut self) {
        self.continue_execution = true;
    }

    /// Returns every recorded hit, oldest first.
    pub fn get_hits(&self) -> &[BreakpointHit] {
        &self.hits
    }

    /// Discards the recorded hits.
    pub fn clear_hits(&mut self) {
        self.hits.clear();
    }
}

impl Default for BreakpointManager {
    fn default() -> Self {
        Self::new()
    }
}
808
809pub struct ExecutionRecorder {
811 enabled: bool,
812 tracer: ExecutionTracer,
813 inspector: TensorInspector,
814 breakpoints: BreakpointManager,
815}
816
817impl ExecutionRecorder {
818 pub fn new() -> Self {
820 Self {
821 enabled: false,
822 tracer: ExecutionTracer::new(),
823 inspector: TensorInspector::new(),
824 breakpoints: BreakpointManager::new(),
825 }
826 }
827
828 pub fn enable(&mut self) {
830 self.enabled = true;
831 self.tracer.enable();
832 self.inspector.enable();
833 self.breakpoints.enable();
834 }
835
836 pub fn disable(&mut self) {
838 self.enabled = false;
839 self.tracer.disable();
840 self.inspector.disable();
841 self.breakpoints.disable();
842 }
843
844 pub fn tracer(&mut self) -> &mut ExecutionTracer {
846 &mut self.tracer
847 }
848
849 pub fn inspector(&mut self) -> &mut TensorInspector {
851 &mut self.inspector
852 }
853
854 pub fn breakpoints(&mut self) -> &mut BreakpointManager {
856 &mut self.breakpoints
857 }
858
859 pub fn clear(&mut self) {
861 self.tracer.clear();
862 self.inspector.clear();
863 self.breakpoints.clear_hits();
864 }
865
866 pub fn generate_report(&self) -> ExecutionReport {
868 ExecutionReport {
869 trace_summary: self.tracer.get_trace().summary(),
870 problematic_tensors: self.inspector.find_problematic_tensors().len(),
871 breakpoint_hits: self.breakpoints.get_hits().len(),
872 }
873 }
874}
875
876impl Default for ExecutionRecorder {
877 fn default() -> Self {
878 Self::new()
879 }
880}
881
882#[derive(Debug, Clone)]
884pub struct ExecutionReport {
885 pub trace_summary: TraceSummary,
887 pub problematic_tensors: usize,
889 pub breakpoint_hits: usize,
891}
892
893impl fmt::Display for ExecutionReport {
894 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
895 writeln!(f, "{}", self.trace_summary)?;
896 writeln!(f, "\nDebug Information:")?;
897 writeln!(f, " Problematic tensors: {}", self.problematic_tensors)?;
898 writeln!(f, " Breakpoint hits: {}", self.breakpoint_hits)?;
899 Ok(())
900 }
901}
902
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a trace entry lasting `millis` milliseconds with empty metadata.
    fn entry(
        entry_id: usize,
        node_id: usize,
        operation: &str,
        millis: u64,
        input_ids: Vec<usize>,
        output_ids: Vec<usize>,
    ) -> TraceEntry {
        TraceEntry {
            entry_id,
            node_id,
            operation: operation.to_string(),
            start_time: Instant::now(),
            duration: Duration::from_millis(millis),
            input_ids,
            output_ids,
            metadata: HashMap::new(),
        }
    }

    /// Builds a trace containing `entries` in the given order.
    fn trace_of(entries: Vec<TraceEntry>) -> ExecutionTrace {
        let mut trace = ExecutionTrace::new();
        for item in entries {
            trace.add_entry(item);
        }
        trace
    }

    /// Collects the node ids along a path.
    fn node_ids(path: &[&TraceEntry]) -> Vec<usize> {
        path.iter().map(|e| e.node_id).collect()
    }

    #[test]
    fn test_execution_tracer() {
        let mut tracer = ExecutionTracer::new();
        assert!(!tracer.is_enabled());

        tracer.enable();
        assert!(tracer.is_enabled());

        tracer.start_trace(Some(1));
        let handle = tracer.record_operation_start(0, "einsum", vec![0, 1]);
        std::thread::sleep(Duration::from_millis(10));
        tracer.record_operation_end(handle, 0, "einsum", vec![0, 1], vec![2], HashMap::new());

        let trace = tracer.get_trace();
        assert_eq!(trace.entries().len(), 1);
        assert!(trace.total_duration_ms() >= 10.0);
    }

    #[test]
    fn test_trace_summary() {
        let trace = trace_of(vec![entry(0, 0, "einsum", 10, vec![0], vec![1])]);

        let summary = trace.summary();
        assert_eq!(summary.total_operations, 1);
        assert!(summary.total_time_ms >= 10.0);
    }

    #[test]
    fn test_tensor_inspector() {
        let mut inspector = TensorInspector::new();
        inspector.enable();

        let stats =
            TensorStats::new(0, vec![2, 3], "f64").with_statistics(0.0, 1.0, 0.5, 0.25, 0, 0);

        inspector.record_stats(stats.clone());
        assert_eq!(inspector.get_stats(0).expect("unwrap").tensor_id, 0);
        assert!(!stats.has_numerical_issues());
    }

    #[test]
    fn test_tensor_numerical_issues() {
        let stats = TensorStats::new(0, vec![2, 3], "f64")
            .with_statistics(0.0, f64::INFINITY, 0.5, 0.25, 1, 1);

        assert!(stats.has_numerical_issues());
    }

    #[test]
    fn test_breakpoint_manager() {
        let mut manager = BreakpointManager::new();
        manager.enable();
        manager.add_node_breakpoint(5);

        let hit = manager.should_break(5, "einsum", 1000, false);
        assert!(hit.is_some());
        assert_eq!(hit.expect("unwrap").node_id, 5);

        manager.continue_execution();
        let hit2 = manager.should_break(5, "einsum", 1000, false);
        assert!(hit2.is_some());
    }

    #[test]
    fn test_operation_breakpoint() {
        let mut manager = BreakpointManager::new();
        manager.enable();
        manager.add_operation_breakpoint("matmul");

        let hit = manager.should_break(1, "matmul", 1000, false);
        assert!(hit.is_some());

        let no_hit = manager.should_break(2, "add", 1000, false);
        assert!(no_hit.is_none());
    }

    #[test]
    fn test_time_threshold_breakpoint() {
        let mut manager = BreakpointManager::new();
        manager.enable();
        manager.add_time_threshold_breakpoint(5000);

        let no_hit = manager.should_break(1, "op", 4000, false);
        assert!(no_hit.is_none());

        let hit = manager.should_break(1, "op", 6000, false);
        assert!(hit.is_some());
    }

    #[test]
    fn test_numerical_issue_breakpoint() {
        let mut manager = BreakpointManager::new();
        manager.enable();
        manager.add_numerical_issue_breakpoint();

        let no_hit = manager.should_break(1, "op", 1000, false);
        assert!(no_hit.is_none());

        let hit = manager.should_break(1, "op", 1000, true);
        assert!(hit.is_some());
    }

    #[test]
    fn test_execution_recorder() {
        let mut recorder = ExecutionRecorder::new();
        recorder.enable();

        assert!(recorder.tracer().is_enabled());
        assert!(recorder.inspector().is_enabled());
        assert!(recorder.breakpoints().is_enabled());

        recorder.clear();
        let report = recorder.generate_report();
        assert_eq!(report.trace_summary.total_operations, 0);
    }

    #[test]
    fn test_slowest_operations() {
        // Entry `i` lasts (i + 1) * 10 ms, so higher ids are slower.
        let trace = trace_of(
            (0..5)
                .map(|i| entry(i, i, &format!("op{}", i), (i as u64 + 1) * 10, vec![], vec![]))
                .collect(),
        );

        let slowest = trace.slowest_operations(3);
        assert_eq!(slowest.len(), 3);
        assert_eq!(slowest[0].node_id, 4);
        assert_eq!(slowest[1].node_id, 3);
        assert_eq!(slowest[2].node_id, 2);
    }

    #[test]
    fn test_watch_list() {
        let mut inspector = TensorInspector::new();
        inspector.enable();

        inspector.watch(1);
        inspector.watch(2);

        assert!(inspector.should_inspect(1));
        assert!(inspector.should_inspect(2));
        assert!(!inspector.should_inspect(3));

        inspector.unwatch(1);
        assert!(!inspector.should_inspect(1));
        assert!(inspector.should_inspect(2));

        // With an empty watch list every tensor is inspected.
        inspector.clear_watch_list();
        assert!(inspector.should_inspect(5));
    }

    #[test]
    fn test_trace_entries_for_node() {
        let trace = trace_of(
            (0..3)
                .map(|i| entry(i, i % 2, "op", 10, vec![], vec![]))
                .collect(),
        );

        let node_0_entries = trace.entries_for_node(0);
        assert_eq!(node_0_entries.len(), 2);

        let node_1_entries = trace.entries_for_node(1);
        assert_eq!(node_1_entries.len(), 1);
    }

    #[test]
    fn test_critical_path_linear_chain() {
        // op0 -> op1 -> op2
        let trace = trace_of(vec![
            entry(0, 0, "op0", 10, vec![], vec![0]),
            entry(1, 1, "op1", 20, vec![0], vec![1]),
            entry(2, 2, "op2", 15, vec![1], vec![2]),
        ]);

        let critical_path = trace.critical_path();
        assert_eq!(critical_path.len(), 3);
        assert_eq!(node_ids(&critical_path), vec![0, 1, 2]);

        let cp_duration = trace.critical_path_duration();
        assert_eq!(cp_duration, Duration::from_millis(45));
    }

    #[test]
    fn test_critical_path_parallel_operations() {
        // op0 feeds both op1 and op2; op2 is the slower branch.
        let trace = trace_of(vec![
            entry(0, 0, "op0", 10, vec![], vec![0, 1]),
            entry(1, 1, "op1", 5, vec![0], vec![2]),
            entry(2, 2, "op2", 20, vec![1], vec![3]),
        ]);

        let critical_path = trace.critical_path();
        assert_eq!(critical_path.len(), 2);
        assert_eq!(node_ids(&critical_path), vec![0, 2]);

        let cp_duration = trace.critical_path_duration();
        assert_eq!(cp_duration, Duration::from_millis(30));
    }

    #[test]
    fn test_critical_path_diamond_pattern() {
        // op0 -> {op1, op2} -> op3; op2 is the slower branch.
        let trace = trace_of(vec![
            entry(0, 0, "op0", 10, vec![], vec![0]),
            entry(1, 1, "op1", 5, vec![0], vec![1]),
            entry(2, 2, "op2", 25, vec![0], vec![2]),
            entry(3, 3, "op3", 15, vec![1, 2], vec![3]),
        ]);

        let critical_path = trace.critical_path();
        assert_eq!(critical_path.len(), 3);
        assert_eq!(node_ids(&critical_path), vec![0, 2, 3]);

        let cp_duration = trace.critical_path_duration();
        assert_eq!(cp_duration, Duration::from_millis(50));
    }

    #[test]
    fn test_critical_path_empty() {
        let trace = ExecutionTrace::new();
        let critical_path = trace.critical_path();
        assert_eq!(critical_path.len(), 0);
        assert_eq!(trace.critical_path_duration(), Duration::ZERO);
    }

    #[test]
    fn test_critical_path_single_operation() {
        let trace = trace_of(vec![entry(0, 0, "op0", 10, vec![], vec![0])]);

        let critical_path = trace.critical_path();
        assert_eq!(critical_path.len(), 1);
        assert_eq!(critical_path[0].node_id, 0);
    }

    #[test]
    fn test_parallelism_factor() {
        // Total work is 100 ms; the critical path op0 -> op2 -> op3 is 80 ms.
        let trace = trace_of(vec![
            entry(0, 0, "op0", 10, vec![], vec![0]),
            entry(1, 1, "op1", 20, vec![0], vec![1]),
            entry(2, 2, "op2", 30, vec![0], vec![2]),
            entry(3, 3, "op3", 40, vec![1, 2], vec![3]),
        ]);

        let parallelism = trace.parallelism_factor();
        assert!((parallelism - 1.25).abs() < 0.01);
    }

    #[test]
    fn test_critical_path_complex_graph() {
        // root fans out to three branches; two of them feed the merge.
        let mut trace = trace_of(vec![entry(0, 0, "root", 5, vec![], vec![0])]);
        for i in 1..=3 {
            trace.add_entry(entry(
                i,
                i,
                &format!("branch{}", i),
                (i as u64) * 10,
                vec![0],
                vec![i],
            ));
        }
        trace.add_entry(entry(4, 4, "merge", 15, vec![2, 3], vec![4]));

        let critical_path = trace.critical_path();
        assert!(critical_path.len() >= 3);
        assert_eq!(critical_path[0].operation, "root");
        assert_eq!(critical_path[critical_path.len() - 1].operation, "merge");
    }
}