Skip to main content

rustorch_core/
graph.rs

1use crate::Tensor;
2use std::collections::HashMap;
3use std::sync::{Arc, Mutex};
4
5// --- Graph Tracing ---
6
/// Operation stored in a single node of a traced [`Graph`].
///
/// `Input` nodes are created by [`register_input`] and `Constant` nodes are created implicitly
/// by [`record_op`]. The remaining variants are supplied by whichever op implementation calls
/// [`record_op`].
///
/// Derives `PartialEq`/`Eq`/`Hash` so ops can be compared or used as map keys (e.g. when
/// pattern-matching or deduplicating nodes in graph passes).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum NodeOp {
    /// Graph input registered through [`register_input`].
    Input,
    /// Operand that the tracer had not seen before it was used by a recorded op. It is assumed
    /// to be a constant or parameter and is created automatically by [`record_op`].
    Constant,
    Add,
    Sub,
    MatMul,
    /// NOTE(review): tuple components are presumably `(height, width)` — confirm with the op impl.
    Conv2d {
        stride: (usize, usize),
        padding: (usize, usize),
    },
    ReLU,
    /// NOTE(review): tuple components are presumably `(height, width)` — confirm with the op impl.
    MaxPool2d {
        kernel_size: (usize, usize),
        stride: (usize, usize),
        padding: (usize, usize),
    },
    BatchNorm2d,
    // ... other ops
}
27
28#[derive(Debug, Clone)]
29pub struct Node {
30    pub id: usize,
31    pub op: NodeOp,
32    pub inputs: Vec<usize>, // Input Node IDs
33    pub shape: Vec<usize>,
34    pub name: Option<String>,
35}
36
37#[derive(Debug, Clone)]
38pub struct Graph {
39    pub nodes: Vec<Node>,
40    pub inputs: Vec<usize>,
41    pub outputs: Vec<usize>,
42}
43
44impl Graph {
45    pub fn new() -> Self {
46        Self {
47            nodes: Vec::new(),
48            inputs: Vec::new(),
49            outputs: Vec::new(),
50        }
51    }
52
53    pub fn add_node(
54        &mut self,
55        op: NodeOp,
56        inputs: Vec<usize>,
57        shape: Vec<usize>,
58        name: Option<String>,
59    ) -> usize {
60        let id = self.nodes.len();
61        self.nodes.push(Node {
62            id,
63            op,
64            inputs,
65            shape,
66            name,
67        });
68        id
69    }
70
71    pub fn print(&self) {
72        println!("Graph:");
73        for node in &self.nodes {
74            println!(
75                "  Node {}: {:?} shape={:?} inputs={:?}",
76                node.id, node.op, node.shape, node.inputs
77            );
78        }
79    }
80}
81
82// Global Tracer Context (Thread Local)
83pub struct TracerContext {
84    pub graph: Graph,
85    pub tensor_map: HashMap<usize, usize>, // Tensor Inner Ptr -> Node ID
86}
87
88thread_local! {
89    static TRACER_CTX: Mutex<Option<TracerContext>> = Mutex::new(None);
90}
91
92pub fn start_tracing() {
93    TRACER_CTX.with(|ctx| {
94        *ctx.lock().unwrap() = Some(TracerContext {
95            graph: Graph::new(),
96            tensor_map: HashMap::new(),
97        });
98    });
99}
100
101pub fn stop_tracing() -> Option<Graph> {
102    TRACER_CTX.with(|ctx| {
103        let mut guard = ctx.lock().unwrap();
104        guard.take().map(|c| c.graph)
105    })
106}
107
108pub fn is_tracing() -> bool {
109    TRACER_CTX.with(|ctx| ctx.lock().unwrap().is_some())
110}
111
112fn get_node_id(tensor: &Tensor) -> Option<usize> {
113    TRACER_CTX.with(|ctx| {
114        if let Some(c) = ctx.lock().unwrap().as_ref() {
115            let ptr = Arc::as_ptr(&tensor.inner) as usize;
116            c.tensor_map.get(&ptr).cloned()
117        } else {
118            None
119        }
120    })
121}
122
123pub fn register_input(tensor: &Tensor, name: String) {
124    TRACER_CTX.with(|ctx| {
125        if let Some(c) = ctx.lock().unwrap().as_mut() {
126            let node_id =
127                c.graph
128                    .add_node(NodeOp::Input, vec![], tensor.shape().to_vec(), Some(name));
129            let ptr = Arc::as_ptr(&tensor.inner) as usize;
130            c.tensor_map.insert(ptr, node_id);
131            c.graph.inputs.push(node_id);
132        }
133    });
134}
135
136pub fn record_op(op: NodeOp, inputs: &[&Tensor], output: &Tensor) {
137    // Collect input IDs
138    let mut input_ids = Vec::new();
139    for t in inputs {
140        if let Some(id) = get_node_id(t) {
141            input_ids.push(id);
142        } else {
143            // If input is not tracked, register as constant?
144            // For now, let's assume it's a constant/param
145            TRACER_CTX.with(|ctx| {
146                if let Some(c) = ctx.lock().unwrap().as_mut() {
147                    let id = c
148                        .graph
149                        .add_node(NodeOp::Constant, vec![], t.shape().to_vec(), None);
150                    let ptr = Arc::as_ptr(&t.inner) as usize;
151                    c.tensor_map.insert(ptr, id);
152                    input_ids.push(id);
153                }
154            });
155        }
156    }
157
158    TRACER_CTX.with(|ctx| {
159        if let Some(c) = ctx.lock().unwrap().as_mut() {
160            let node_id = c
161                .graph
162                .add_node(op, input_ids, output.shape().to_vec(), None);
163            let ptr = Arc::as_ptr(&output.inner) as usize;
164            c.tensor_map.insert(ptr, node_id);
165        }
166    });
167}