1use crate::Tensor;
2use std::collections::HashMap;
3use std::sync::{Arc, Mutex};
4
/// The kind of operation a traced [`Node`] represents.
///
/// Derives `PartialEq`/`Eq` in addition to `Debug`/`Clone` so recorded ops
/// can be compared directly instead of through their `Debug` output.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NodeOp {
    /// A named graph input, created by `register_input`.
    Input,
    /// A tensor that fed a traced op without already having a node;
    /// `record_op` adds it to the graph on first use.
    Constant,
    Add,
    Sub,
    MatMul,
    /// 2-D convolution.
    /// NOTE(review): the pairs are presumably (height, width) — confirm.
    Conv2d {
        stride: (usize, usize),
        padding: (usize, usize),
    },
    ReLU,
    /// 2-D max pooling.
    /// NOTE(review): the pairs are presumably (height, width) — confirm.
    MaxPool2d {
        kernel_size: (usize, usize),
        stride: (usize, usize),
        padding: (usize, usize),
    },
    BatchNorm2d,
}
27
28#[derive(Debug, Clone)]
29pub struct Node {
30 pub id: usize,
31 pub op: NodeOp,
32 pub inputs: Vec<usize>, pub shape: Vec<usize>,
34 pub name: Option<String>,
35}
36
37#[derive(Debug, Clone)]
38pub struct Graph {
39 pub nodes: Vec<Node>,
40 pub inputs: Vec<usize>,
41 pub outputs: Vec<usize>,
42}
43
44impl Graph {
45 pub fn new() -> Self {
46 Self {
47 nodes: Vec::new(),
48 inputs: Vec::new(),
49 outputs: Vec::new(),
50 }
51 }
52
53 pub fn add_node(
54 &mut self,
55 op: NodeOp,
56 inputs: Vec<usize>,
57 shape: Vec<usize>,
58 name: Option<String>,
59 ) -> usize {
60 let id = self.nodes.len();
61 self.nodes.push(Node {
62 id,
63 op,
64 inputs,
65 shape,
66 name,
67 });
68 id
69 }
70
71 pub fn print(&self) {
72 println!("Graph:");
73 for node in &self.nodes {
74 println!(
75 " Node {}: {:?} shape={:?} inputs={:?}",
76 node.id, node.op, node.shape, node.inputs
77 );
78 }
79 }
80}
81
82pub struct TracerContext {
84 pub graph: Graph,
85 pub tensor_map: HashMap<usize, usize>, }
87
88thread_local! {
89 static TRACER_CTX: Mutex<Option<TracerContext>> = Mutex::new(None);
90}
91
92pub fn start_tracing() {
93 TRACER_CTX.with(|ctx| {
94 *ctx.lock().unwrap() = Some(TracerContext {
95 graph: Graph::new(),
96 tensor_map: HashMap::new(),
97 });
98 });
99}
100
101pub fn stop_tracing() -> Option<Graph> {
102 TRACER_CTX.with(|ctx| {
103 let mut guard = ctx.lock().unwrap();
104 guard.take().map(|c| c.graph)
105 })
106}
107
108pub fn is_tracing() -> bool {
109 TRACER_CTX.with(|ctx| ctx.lock().unwrap().is_some())
110}
111
112fn get_node_id(tensor: &Tensor) -> Option<usize> {
113 TRACER_CTX.with(|ctx| {
114 if let Some(c) = ctx.lock().unwrap().as_ref() {
115 let ptr = Arc::as_ptr(&tensor.inner) as usize;
116 c.tensor_map.get(&ptr).cloned()
117 } else {
118 None
119 }
120 })
121}
122
123pub fn register_input(tensor: &Tensor, name: String) {
124 TRACER_CTX.with(|ctx| {
125 if let Some(c) = ctx.lock().unwrap().as_mut() {
126 let node_id =
127 c.graph
128 .add_node(NodeOp::Input, vec![], tensor.shape().to_vec(), Some(name));
129 let ptr = Arc::as_ptr(&tensor.inner) as usize;
130 c.tensor_map.insert(ptr, node_id);
131 c.graph.inputs.push(node_id);
132 }
133 });
134}
135
136pub fn record_op(op: NodeOp, inputs: &[&Tensor], output: &Tensor) {
137 let mut input_ids = Vec::new();
139 for t in inputs {
140 if let Some(id) = get_node_id(t) {
141 input_ids.push(id);
142 } else {
143 TRACER_CTX.with(|ctx| {
146 if let Some(c) = ctx.lock().unwrap().as_mut() {
147 let id = c
148 .graph
149 .add_node(NodeOp::Constant, vec![], t.shape().to_vec(), None);
150 let ptr = Arc::as_ptr(&t.inner) as usize;
151 c.tensor_map.insert(ptr, id);
152 input_ids.push(id);
153 }
154 });
155 }
156 }
157
158 TRACER_CTX.with(|ctx| {
159 if let Some(c) = ctx.lock().unwrap().as_mut() {
160 let node_id = c
161 .graph
162 .add_node(op, input_ids, output.shape().to_vec(), None);
163 let ptr = Arc::as_ptr(&output.inner) as usize;
164 c.tensor_map.insert(ptr, node_id);
165 }
166 });
167}