1use std::collections::{HashMap, VecDeque};
44use std::sync::{Arc, Mutex, OnceLock};
45use std::time::{Duration, Instant};
46
47use crate::dtype::DType;
48use crate::runtime_config::RuntimeConfig;
49
/// Process-wide tracer state shared by every handle returned from
/// [`OpTracer::global`]; it is created lazily on the first such call.
static OP_TRACER: OnceLock<Arc<Mutex<OpTracerInternal>>> = OnceLock::new();

/// Identifier handed out for each started trace.
pub type TraceId = u64;
55
/// Settings that decide which operations get traced and how many traces are kept.
#[derive(Debug, Clone)]
pub struct TraceConfig {
    /// Master switch: while `false`, no new trace is started.
    pub enabled: bool,
    /// Maximum number of traces retained; the oldest ones are evicted beyond this.
    pub max_traces: usize,
    /// Flag for capturing tensor values (not read in the visible part of this
    /// file — TODO confirm where it is consumed).
    pub capture_values: bool,
    /// Flag for capturing outputs (not read in the visible part of this file —
    /// TODO confirm where it is consumed).
    pub capture_outputs: bool,
    /// Flag for capturing a stack trace (not read in the visible part of this
    /// file — TODO confirm where it is consumed).
    pub capture_stack_trace: bool,
    /// Substring patterns; when non-empty, only operations whose name contains
    /// at least one of them are traced.
    pub operation_filters: Vec<String>,
    /// Nesting limit for traces; `0` means no limit.
    pub max_depth: usize,
    /// Flag for breaking on errors (not read in the visible part of this file —
    /// TODO confirm where it is consumed).
    pub break_on_error: bool,
}

impl Default for TraceConfig {
    /// Tracing starts switched off, keeps up to 10 000 traces, captures nothing
    /// extra, applies no filters or depth limit, and has `break_on_error` set.
    fn default() -> Self {
        TraceConfig {
            // Switches.
            enabled: false,
            break_on_error: true,
            // Optional capture features.
            capture_values: false,
            capture_outputs: false,
            capture_stack_trace: false,
            // Limits and filters.
            max_traces: 10_000,
            max_depth: 0,
            operation_filters: Vec::new(),
        }
    }
}
91
92#[derive(Debug, Clone)]
94pub struct TensorMetadata {
95 pub name: String,
97 pub shape: Vec<usize>,
99 pub dtype: Option<DType>,
101 pub numel: usize,
103 pub size_bytes: usize,
105 pub is_contiguous: bool,
107 pub values: Option<Vec<f64>>,
109}
110
111impl TensorMetadata {
112 pub fn new(name: impl Into<String>, shape: Vec<usize>) -> Self {
114 let numel = shape.iter().product();
115 Self {
116 name: name.into(),
117 shape,
118 dtype: None,
119 numel,
120 size_bytes: 0,
121 is_contiguous: true,
122 values: None,
123 }
124 }
125
126 pub fn with_dtype(mut self, dtype: DType) -> Self {
128 self.size_bytes = self.numel * dtype.size();
129 self.dtype = Some(dtype);
130 self
131 }
132
133 pub fn with_contiguous(mut self, is_contiguous: bool) -> Self {
135 self.is_contiguous = is_contiguous;
136 self
137 }
138
139 pub fn with_values(mut self, values: Vec<f64>) -> Self {
141 self.values = Some(values);
142 self
143 }
144}
145
146#[derive(Debug, Clone)]
148pub struct OperationTrace {
149 pub id: TraceId,
151 pub parent_id: Option<TraceId>,
153 pub operation: String,
155 pub category: Option<String>,
157 pub inputs: Vec<TensorMetadata>,
159 pub outputs: Vec<TensorMetadata>,
161 pub start_time: Instant,
163 pub duration: Option<Duration>,
165 pub depth: usize,
167 pub stack_trace: Option<String>,
169 pub metadata: HashMap<String, String>,
171 pub had_error: bool,
173 pub error_message: Option<String>,
175}
176
177impl OperationTrace {
178 fn new(id: TraceId, parent_id: Option<TraceId>, operation: String, depth: usize) -> Self {
180 Self {
181 id,
182 parent_id,
183 operation,
184 category: None,
185 inputs: Vec::new(),
186 outputs: Vec::new(),
187 start_time: Instant::now(),
188 duration: None,
189 depth,
190 stack_trace: None,
191 metadata: HashMap::new(),
192 had_error: false,
193 error_message: None,
194 }
195 }
196
197 pub fn set_category(&mut self, category: impl Into<String>) {
199 self.category = Some(category.into());
200 }
201
202 pub fn add_input(&mut self, input: TensorMetadata) {
204 self.inputs.push(input);
205 }
206
207 pub fn add_output(&mut self, output: TensorMetadata) {
209 self.outputs.push(output);
210 }
211
212 pub fn add_metadata(&mut self, key: impl Into<String>, value: impl Into<String>) {
214 self.metadata.insert(key.into(), value.into());
215 }
216
217 pub fn complete(&mut self) {
219 self.duration = Some(self.start_time.elapsed());
220 }
221
222 pub fn mark_error(&mut self, error: impl Into<String>) {
224 self.had_error = true;
225 self.error_message = Some(error.into());
226 self.complete();
227 }
228
229 fn matches_filter(&self, filter: &str) -> bool {
231 self.operation.contains(filter)
232 || self.category.as_ref().map_or(false, |c| c.contains(filter))
233 }
234}
235
236pub struct TraceBuilder {
238 trace_id: TraceId,
239}
240
241impl TraceBuilder {
242 fn new(trace_id: TraceId) -> Self {
243 Self { trace_id }
244 }
245
246 pub fn record_input(&self, name: impl Into<String>, shape: Vec<usize>) {
248 let metadata = TensorMetadata::new(name, shape);
249 if let Some(tracer) = OP_TRACER.get() {
250 if let Ok(mut tracer) = tracer.lock() {
251 if let Some(trace) = tracer.traces.get_mut(&self.trace_id) {
252 trace.add_input(metadata);
253 }
254 }
255 }
256 }
257
258 pub fn record_input_with_dtype(
260 &self,
261 name: impl Into<String>,
262 shape: Vec<usize>,
263 dtype: DType,
264 ) {
265 let metadata = TensorMetadata::new(name, shape).with_dtype(dtype);
266 if let Some(tracer) = OP_TRACER.get() {
267 if let Ok(mut tracer) = tracer.lock() {
268 if let Some(trace) = tracer.traces.get_mut(&self.trace_id) {
269 trace.add_input(metadata);
270 }
271 }
272 }
273 }
274
275 pub fn record_output(&self, name: impl Into<String>, shape: Vec<usize>) {
277 let metadata = TensorMetadata::new(name, shape);
278 if let Some(tracer) = OP_TRACER.get() {
279 if let Ok(mut tracer) = tracer.lock() {
280 if let Some(trace) = tracer.traces.get_mut(&self.trace_id) {
281 trace.add_output(metadata);
282 }
283 }
284 }
285 }
286
287 pub fn record_output_with_dtype(
289 &self,
290 name: impl Into<String>,
291 shape: Vec<usize>,
292 dtype: DType,
293 ) {
294 let metadata = TensorMetadata::new(name, shape).with_dtype(dtype);
295 if let Some(tracer) = OP_TRACER.get() {
296 if let Ok(mut tracer) = tracer.lock() {
297 if let Some(trace) = tracer.traces.get_mut(&self.trace_id) {
298 trace.add_output(metadata);
299 }
300 }
301 }
302 }
303
304 pub fn add_metadata(&self, key: impl Into<String>, value: impl Into<String>) {
306 if let Some(tracer) = OP_TRACER.get() {
307 if let Ok(mut tracer) = tracer.lock() {
308 if let Some(trace) = tracer.traces.get_mut(&self.trace_id) {
309 trace.add_metadata(key, value);
310 }
311 }
312 }
313 }
314
315 pub fn set_category(&self, category: impl Into<String>) {
317 if let Some(tracer) = OP_TRACER.get() {
318 if let Ok(mut tracer) = tracer.lock() {
319 if let Some(trace) = tracer.traces.get_mut(&self.trace_id) {
320 trace.set_category(category);
321 }
322 }
323 }
324 }
325}
326
327struct OpTracerInternal {
329 config: TraceConfig,
330 traces: HashMap<TraceId, OperationTrace>,
331 trace_order: VecDeque<TraceId>,
332 next_id: TraceId,
333 current_depth: usize,
334 depth_stack: Vec<TraceId>,
335 breakpoints: HashMap<String, bool>, }
337
338impl OpTracerInternal {
339 fn new() -> Self {
340 Self {
341 config: TraceConfig::default(),
342 traces: HashMap::new(),
343 trace_order: VecDeque::new(),
344 next_id: 1,
345 current_depth: 0,
346 depth_stack: Vec::new(),
347 breakpoints: HashMap::new(),
348 }
349 }
350
351 fn should_trace(&self, operation: &str) -> bool {
352 if !self.config.enabled {
353 return false;
354 }
355
356 if self.config.max_depth > 0 && self.current_depth >= self.config.max_depth {
358 return false;
359 }
360
361 if !self.config.operation_filters.is_empty() {
363 return self
364 .config
365 .operation_filters
366 .iter()
367 .any(|f| operation.contains(f));
368 }
369
370 true
371 }
372
373 fn start_trace(&mut self, operation: String) -> Option<TraceId> {
374 if !self.should_trace(&operation) {
375 return None;
376 }
377
378 let trace_id = self.next_id;
379 self.next_id += 1;
380
381 let parent_id = self.depth_stack.last().copied();
382 let trace = OperationTrace::new(trace_id, parent_id, operation, self.current_depth);
383
384 self.traces.insert(trace_id, trace);
385 self.trace_order.push_back(trace_id);
386 self.depth_stack.push(trace_id);
387 self.current_depth += 1;
388
389 while self.trace_order.len() > self.config.max_traces {
391 if let Some(old_id) = self.trace_order.pop_front() {
392 self.traces.remove(&old_id);
393 }
394 }
395
396 Some(trace_id)
397 }
398
399 fn complete_trace(&mut self, trace_id: TraceId) {
400 if let Some(trace) = self.traces.get_mut(&trace_id) {
401 trace.complete();
402 }
403
404 if self.depth_stack.last() == Some(&trace_id) {
406 self.depth_stack.pop();
407 if self.current_depth > 0 {
408 self.current_depth -= 1;
409 }
410 }
411 }
412
413 fn mark_error(&mut self, trace_id: TraceId, error: String) {
414 if let Some(trace) = self.traces.get_mut(&trace_id) {
415 trace.mark_error(error);
416 }
417 }
418}
419
420pub struct OpTracer {
422 inner: Arc<Mutex<OpTracerInternal>>,
423}
424
425impl OpTracer {
426 pub fn global() -> Self {
428 let inner = OP_TRACER
429 .get_or_init(|| Arc::new(Mutex::new(OpTracerInternal::new())))
430 .clone();
431 Self { inner }
432 }
433
434 pub fn new() -> Self {
436 Self {
437 inner: Arc::new(Mutex::new(OpTracerInternal::new())),
438 }
439 }
440
441 pub fn set_enabled(&self, enabled: bool) {
443 if let Ok(mut tracer) = self.inner.lock() {
444 tracer.config.enabled = enabled;
445 }
446 }
447
448 pub fn is_enabled(&self) -> bool {
450 self.inner.lock().map_or(false, |t| t.config.enabled)
451 }
452
453 pub fn set_config(&self, config: TraceConfig) {
455 if let Ok(mut tracer) = self.inner.lock() {
456 tracer.config = config;
457 }
458 }
459
460 pub fn get_config(&self) -> TraceConfig {
462 self.inner
463 .lock()
464 .map_or(TraceConfig::default(), |t| t.config.clone())
465 }
466
467 pub fn add_filter(&self, pattern: impl Into<String>) {
469 if let Ok(mut tracer) = self.inner.lock() {
470 tracer.config.operation_filters.push(pattern.into());
471 }
472 }
473
474 pub fn clear_filters(&self) {
476 if let Ok(mut tracer) = self.inner.lock() {
477 tracer.config.operation_filters.clear();
478 }
479 }
480
481 pub fn set_breakpoint(&self, operation: impl Into<String>) {
483 if let Ok(mut tracer) = self.inner.lock() {
484 tracer.breakpoints.insert(operation.into(), true);
485 }
486 }
487
488 pub fn remove_breakpoint(&self, operation: &str) {
490 if let Ok(mut tracer) = self.inner.lock() {
491 tracer.breakpoints.remove(operation);
492 }
493 }
494
495 pub fn has_breakpoint(&self, operation: &str) -> bool {
497 self.inner.lock().map_or(false, |t| {
498 t.breakpoints.get(operation).copied().unwrap_or(false)
499 })
500 }
501
502 pub fn get_trace(&self, trace_id: TraceId) -> Option<OperationTrace> {
504 self.inner.lock().ok()?.traces.get(&trace_id).cloned()
505 }
506
507 pub fn get_all_traces(&self) -> Vec<OperationTrace> {
509 self.inner.lock().map_or(Vec::new(), |t| {
510 t.trace_order
511 .iter()
512 .filter_map(|id| t.traces.get(id).cloned())
513 .collect()
514 })
515 }
516
517 pub fn get_filtered_traces(&self, filter: &str) -> Vec<OperationTrace> {
519 self.inner.lock().map_or(Vec::new(), |t| {
520 t.trace_order
521 .iter()
522 .filter_map(|id| t.traces.get(id))
523 .filter(|trace| trace.matches_filter(filter))
524 .cloned()
525 .collect()
526 })
527 }
528
529 pub fn clear_traces(&self) {
531 if let Ok(mut tracer) = self.inner.lock() {
532 tracer.traces.clear();
533 tracer.trace_order.clear();
534 }
535 }
536
537 pub fn get_statistics(&self) -> TraceStatistics {
539 let tracer = match self.inner.lock() {
540 Ok(t) => t,
541 Err(_) => return TraceStatistics::default(),
542 };
543
544 let total_traces = tracer.traces.len();
545 let total_errors = tracer.traces.values().filter(|t| t.had_error).count();
546
547 let total_duration: Duration = tracer.traces.values().filter_map(|t| t.duration).sum();
548
549 let operations_by_type: HashMap<String, usize> =
550 tracer
551 .traces
552 .values()
553 .fold(HashMap::new(), |mut acc, trace| {
554 *acc.entry(trace.operation.clone()).or_insert(0) += 1;
555 acc
556 });
557
558 TraceStatistics {
559 total_traces,
560 total_errors,
561 total_duration,
562 operations_by_type,
563 }
564 }
565}
566
567impl Default for OpTracer {
568 fn default() -> Self {
569 Self::global()
570 }
571}
572
/// Aggregate figures computed over the traces a tracer currently retains.
#[derive(Debug, Clone, Default)]
pub struct TraceStatistics {
    /// Number of retained traces.
    pub total_traces: usize,
    /// Number of retained traces that were marked as failed.
    pub total_errors: usize,
    /// Sum of the durations of all traces that have one recorded.
    pub total_duration: Duration,
    /// Number of retained traces per operation name.
    pub operations_by_type: HashMap<String, usize>,
}
581
582pub fn trace_operation<F>(operation: impl Into<String>, f: F) -> Option<TraceId>
596where
597 F: FnOnce(&TraceBuilder),
598{
599 let operation = operation.into();
600 let tracer = OpTracer::global();
601
602 let runtime_config = RuntimeConfig::global();
604 if !runtime_config.should_collect_metrics(&operation) {
605 return None;
606 }
607
608 let trace_id = {
609 let mut inner = tracer.inner.lock().ok()?;
610 inner.start_trace(operation.clone())?
611 };
612
613 let builder = TraceBuilder::new(trace_id);
614 f(&builder);
615
616 {
617 let mut inner = tracer.inner.lock().ok()?;
618 inner.complete_trace(trace_id);
619 }
620
621 Some(trace_id)
622}
623
624pub fn trace_operation_result<F, T, E>(operation: impl Into<String>, f: F) -> Result<T, E>
626where
627 F: FnOnce(&TraceBuilder) -> Result<T, E>,
628 E: std::fmt::Display,
629{
630 let operation = operation.into();
631 let tracer = OpTracer::global();
632
633 let trace_id = {
634 let mut inner = tracer.inner.lock().ok().ok_or_else(|| {
635 panic!("Failed to acquire tracer lock")
637 })?;
638 inner.start_trace(operation.clone())
639 };
640
641 let builder = trace_id.map(TraceBuilder::new);
642
643 let result = match builder.as_ref() {
644 Some(b) => f(b),
645 None => f(&TraceBuilder::new(0)), };
647
648 if let Some(tid) = trace_id {
649 let mut inner = tracer
650 .inner
651 .lock()
652 .ok()
653 .ok_or_else(|| panic!("Failed to acquire tracer lock"))?;
654
655 match &result {
656 Ok(_) => inner.complete_trace(tid),
657 Err(e) => inner.mark_error(tid, e.to_string()),
658 }
659 }
660
661 result
662}
663
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a private tracer with tracing switched on.
    fn enabled_tracer() -> OpTracer {
        let tracer = OpTracer::new();
        tracer.set_enabled(true);
        tracer
    }

    /// Asks the tracer to open a trace; `None` means the tracer declined it.
    fn try_start(tracer: &OpTracer, operation: &str) -> Option<TraceId> {
        let mut inner = tracer.inner.lock().expect("lock should not be poisoned");
        inner.start_trace(operation.to_string())
    }

    /// Opens a trace and insists that the tracer accepted it.
    fn start(tracer: &OpTracer, operation: &str) -> TraceId {
        try_start(tracer, operation).expect("start_trace should succeed")
    }

    /// Completes a previously opened trace.
    fn finish(tracer: &OpTracer, trace_id: TraceId) {
        let mut inner = tracer.inner.lock().expect("lock should not be poisoned");
        inner.complete_trace(trace_id);
    }

    #[test]
    fn test_tracer_enable_disable() {
        let tracer = OpTracer::new();
        assert!(!tracer.is_enabled());

        tracer.set_enabled(true);
        assert!(tracer.is_enabled());

        tracer.set_enabled(false);
        assert!(!tracer.is_enabled());
    }

    #[test]
    fn test_trace_operation() {
        let tracer = enabled_tracer();

        let trace_id = start(&tracer, "test_op");
        assert!(tracer.get_trace(trace_id).is_some());

        finish(&tracer, trace_id);

        let trace = tracer.get_trace(trace_id).expect("trace should exist");
        assert_eq!(trace.operation, "test_op");
        assert!(trace.duration.is_some());
    }

    #[test]
    fn test_trace_with_inputs_outputs() {
        let tracer = enabled_tracer();
        let trace_id = start(&tracer, "matmul");

        {
            // Attach tensor descriptions directly to the stored trace.
            let mut inner = tracer.inner.lock().expect("lock should not be poisoned");
            if let Some(trace) = inner.traces.get_mut(&trace_id) {
                let lhs = TensorMetadata::new("lhs", vec![10, 20]).with_dtype(DType::F32);
                let rhs = TensorMetadata::new("rhs", vec![20, 30]).with_dtype(DType::F32);
                let result = TensorMetadata::new("result", vec![10, 30]).with_dtype(DType::F32);
                trace.add_input(lhs);
                trace.add_input(rhs);
                trace.add_output(result);
            }
        }

        finish(&tracer, trace_id);

        let trace = tracer.get_trace(trace_id).expect("trace should exist");
        assert_eq!(trace.inputs.len(), 2);
        assert_eq!(trace.outputs.len(), 1);
        assert_eq!(trace.inputs[0].shape, vec![10, 20]);
        assert_eq!(trace.outputs[0].shape, vec![10, 30]);
    }

    #[test]
    fn test_trace_filtering() {
        let tracer = enabled_tracer();
        tracer.add_filter("matmul");

        // Matches the filter, so it is traced.
        let trace_id1 = try_start(&tracer, "matmul");
        assert!(trace_id1.is_some());

        // Does not match the filter, so it is skipped.
        let trace_id2 = try_start(&tracer, "add");
        assert!(trace_id2.is_none());
    }

    #[test]
    fn test_trace_hierarchy() {
        let tracer = enabled_tracer();

        let parent_id = start(&tracer, "parent_op");
        let child_id = start(&tracer, "child_op");

        finish(&tracer, child_id);
        finish(&tracer, parent_id);

        let parent_trace = tracer
            .get_trace(parent_id)
            .expect("parent trace should exist");
        let child_trace = tracer
            .get_trace(child_id)
            .expect("child trace should exist");

        assert_eq!(parent_trace.depth, 0);
        assert_eq!(child_trace.depth, 1);
        assert_eq!(child_trace.parent_id, Some(parent_id));
    }

    #[test]
    fn test_breakpoints() {
        let tracer = OpTracer::new();

        tracer.set_breakpoint("critical_op");
        assert!(tracer.has_breakpoint("critical_op"));

        tracer.remove_breakpoint("critical_op");
        assert!(!tracer.has_breakpoint("critical_op"));
    }

    #[test]
    fn test_trace_statistics() {
        let tracer = enabled_tracer();

        for i in 0..5 {
            let trace_id = start(&tracer, &format!("op_{}", i));
            finish(&tracer, trace_id);
        }

        let stats = tracer.get_statistics();
        assert_eq!(stats.total_traces, 5);
        assert_eq!(stats.total_errors, 0);
    }

    #[test]
    fn test_error_tracing() {
        let tracer = enabled_tracer();
        let trace_id = start(&tracer, "failing_op");

        {
            let mut inner = tracer.inner.lock().expect("lock should not be poisoned");
            inner.mark_error(trace_id, "Test error".to_string());
        }

        let trace = tracer.get_trace(trace_id).expect("trace should exist");
        assert!(trace.had_error);
        assert_eq!(trace.error_message, Some("Test error".to_string()));

        let stats = tracer.get_statistics();
        assert_eq!(stats.total_errors, 1);
    }

    #[test]
    fn test_max_traces_limit() {
        let tracer = OpTracer::new();
        tracer.set_config(TraceConfig {
            enabled: true,
            max_traces: 5,
            ..TraceConfig::default()
        });

        for i in 0..10 {
            let trace_id = start(&tracer, &format!("op_{}", i));
            finish(&tracer, trace_id);
        }

        // Only the five most recent traces survive eviction.
        let all_traces = tracer.get_all_traces();
        assert_eq!(all_traces.len(), 5);
    }

    #[test]
    fn test_clear_traces() {
        let tracer = enabled_tracer();

        let trace_id = start(&tracer, "test_op");
        assert!(tracer.get_trace(trace_id).is_some());

        tracer.clear_traces();
        assert!(tracer.get_trace(trace_id).is_none());
    }
}