//! `rustorch_core/jit.rs` — a minimal JIT-style tensor graph: IR definition,
//! a Conv2d+Relu fusion optimizer pass, and a simple interpreting executor.

1use crate::Tensor;
2use std::collections::HashMap;
3
4// --- IR Definition ---
5
/// Operation kind of a graph node. `usize` payloads are indices of
/// producer nodes in `Graph::nodes` (the graph is stored in topological
/// order, so an op only references smaller indices).
#[derive(Clone, Debug, PartialEq)]
pub enum NodeType {
    Input(usize),   // Input index (position in `Graph::inputs`)
    Weight(Tensor), // Captured weight (constant), embedded directly in the IR

    // Ops
    Add(usize, usize), // LHS, RHS node indices
    Mul(usize, usize),
    MatMul(usize, usize),
    Relu(usize),
    Conv2d(usize, usize, (usize, usize), (usize, usize)), // Input, Weight, Stride, Padding

    // Fused Ops — produced by the optimizer, never by tracing directly.
    Conv2dRelu(usize, usize, (usize, usize), (usize, usize)),
    LinearRelu(usize, usize, usize), // Input, Weight, Bias (Optional?)
}
22
/// One node of the IR graph: an operation plus its inferred output shape.
#[derive(Debug)]
pub struct Node {
    /// The operation this node performs; operands are node ids.
    pub op: NodeType,
    /// Shape of the tensor this node produces.
    pub shape: Vec<usize>,
    /// This node's index within `Graph::nodes` (kept in sync by `add_node`).
    pub id: usize,
    // dependencies, users, etc.
}
30
/// A computation graph: nodes in topological order plus the ids of the
/// nodes that act as graph inputs and outputs.
#[derive(Debug)]
pub struct Graph {
    /// All nodes; a node's position equals its `id`.
    pub nodes: Vec<Node>,
    /// Node ids of `NodeType::Input` placeholders, in input order.
    pub inputs: Vec<usize>,
    /// Node ids whose values the executor returns.
    pub outputs: Vec<usize>,
}
37
38impl Default for Graph {
39    fn default() -> Self {
40        Self::new()
41    }
42}
43
44impl Graph {
45    pub fn new() -> Self {
46        Self {
47            nodes: Vec::new(),
48            inputs: Vec::new(),
49            outputs: Vec::new(),
50        }
51    }
52
53    pub fn add_node(&mut self, op: NodeType, shape: Vec<usize>) -> usize {
54        let id = self.nodes.len();
55        self.nodes.push(Node { op, shape, id });
56        id
57    }
58
59    pub fn add_input(&mut self, shape: Vec<usize>) -> usize {
60        let id = self.add_node(NodeType::Input(self.inputs.len()), shape);
61        self.inputs.push(id);
62        id
63    }
64
65    pub fn add_weight(&mut self, tensor: Tensor) -> usize {
66        self.add_node(NodeType::Weight(tensor.clone()), tensor.shape().to_vec())
67    }
68}
69
70// --- Tracer ---
71// A simple tracer that records operations.
72// In a real framework, we would use a thread-local graph context or proxy tensors.
73// Here we simulate tracing by manually building graph or using a "TracedTensor" wrapper.
74
75// Let's implement a simple "Optimizer" pass.
76
/// Graph-rewriting optimization passes (currently only Conv2d+Relu fusion).
pub struct Optimizer;
78
79impl Optimizer {
80    pub fn optimize(graph: &mut Graph) {
81        Self::fuse_conv_relu(graph);
82        // Self::eliminate_dead_code(graph);
83    }
84
85    fn fuse_conv_relu(graph: &mut Graph) {
86        // Look for Conv2d -> Relu pattern
87        // This requires analyzing graph topology.
88        // For simplicity: Iterate nodes, if Relu(Conv2d(idx)), replace op.
89
90        // We can't easily modify Vec while iterating.
91        // And we need to redirect edges.
92        // Simplified approach: Build new graph.
93
94        let mut new_nodes = Vec::new();
95        let mut mapping = HashMap::new(); // Old ID -> New ID
96
97        // We iterate old nodes.
98        // If we see Conv2d, we look ahead? No, usually we look at Relu and check input.
99
100        // But to rebuild, we visit in topological order (which is index order here).
101
102        let n = graph.nodes.len();
103        let mut consumed = vec![false; n];
104
105        for i in 0..n {
106            if consumed[i] {
107                continue;
108            }
109
110            let node = &graph.nodes[i];
111
112            match &node.op {
113                NodeType::Conv2d(input_id, weight_id, stride, padding) => {
114                    // Check if this node is used ONLY by a Relu
115                    // If so, we can fuse.
116                    // We need use-def chains.
117                    // For this demo, let's peek ahead.
118                    // If next node is Relu and takes this Conv2d as input, fuse.
119                    // (This assumes linear ordering which is not guaranteed but common in sequential models)
120
121                    let mut fused = false;
122                    // Find if any future node is Relu(i)
123                    // Optimization: just check if next one is Relu(i)
124                    if i + 1 < n {
125                        if let NodeType::Relu(inp) = graph.nodes[i + 1].op {
126                            if inp == i {
127                                // Found Fusion!
128                                let new_id = new_nodes.len();
129                                mapping.insert(i + 1, new_id); // Relu maps to Fused
130                                                               // Conv2d node maps to Fused?
131                                                               // Actually the output of Relu is the output of Fused.
132                                                               // The output of Conv2d is consumed.
133
134                                // Remap inputs
135                                let new_input = *mapping.get(input_id).unwrap_or(input_id);
136                                let new_weight = *mapping.get(weight_id).unwrap_or(weight_id);
137
138                                new_nodes.push(Node {
139                                    op: NodeType::Conv2dRelu(
140                                        new_input, new_weight, *stride, *padding,
141                                    ),
142                                    shape: graph.nodes[i + 1].shape.clone(),
143                                    id: new_id,
144                                });
145
146                                consumed[i + 1] = true; // Skip Relu
147                                fused = true;
148                            }
149                        }
150                    }
151
152                    if !fused {
153                        // Copy Conv2d
154                        let new_id = new_nodes.len();
155                        mapping.insert(i, new_id);
156                        let new_input = *mapping.get(input_id).unwrap_or(input_id);
157                        let new_weight = *mapping.get(weight_id).unwrap_or(weight_id);
158
159                        new_nodes.push(Node {
160                            op: NodeType::Conv2d(new_input, new_weight, *stride, *padding),
161                            shape: node.shape.clone(),
162                            id: new_id,
163                        });
164                    }
165                }
166
167                // Generic copy for others
168                op => {
169                    let new_id = new_nodes.len();
170                    mapping.insert(i, new_id);
171
172                    // Remap inputs
173                    let new_op = match op {
174                        NodeType::Add(a, b) => NodeType::Add(
175                            *mapping.get(a).unwrap_or(a),
176                            *mapping.get(b).unwrap_or(b),
177                        ),
178                        NodeType::Mul(a, b) => NodeType::Mul(
179                            *mapping.get(a).unwrap_or(a),
180                            *mapping.get(b).unwrap_or(b),
181                        ),
182                        NodeType::Relu(a) => NodeType::Relu(*mapping.get(a).unwrap_or(a)),
183                        // ... copy others
184                        _ => op.clone(),
185                    };
186
187                    new_nodes.push(Node {
188                        op: new_op,
189                        shape: node.shape.clone(),
190                        id: new_id,
191                    });
192                }
193            }
194        }
195
196        graph.nodes = new_nodes;
197        // Remap outputs
198        for out in &mut graph.outputs {
199            if let Some(&new_id) = mapping.get(out) {
200                *out = new_id;
201            }
202        }
203        // Remap inputs (Node IDs)
204        for inp in &mut graph.inputs {
205            if let Some(&new_id) = mapping.get(inp) {
206                *inp = new_id;
207            }
208        }
209    }
210}
211
// --- Executor ---
/// Stateless interpreter that evaluates a `Graph` node by node.
pub struct Executor;
214
215impl Executor {
216    pub fn run(graph: &Graph, inputs: &[Tensor]) -> Vec<Tensor> {
217        let mut values: HashMap<usize, Tensor> = HashMap::new();
218
219        // Load inputs
220        for (i, &id) in graph.inputs.iter().enumerate() {
221            values.insert(id, inputs[i].clone());
222        }
223
224        for node in &graph.nodes {
225            if values.contains_key(&node.id) {
226                continue;
227            } // Already computed (Input/Weight)
228
229            let val = match &node.op {
230                NodeType::Input(_) => panic!("Input should be loaded"),
231                NodeType::Weight(t) => t.clone(),
232
233                NodeType::Add(a, b) => {
234                    let va = values.get(a).unwrap();
235                    let vb = values.get(b).unwrap();
236                    va.add(vb)
237                }
238                NodeType::Mul(a, b) => {
239                    let va = values.get(a).unwrap();
240                    let vb = values.get(b).unwrap();
241                    va.mul(vb)
242                }
243                NodeType::Relu(a) => {
244                    let va = values.get(a).unwrap();
245                    va.relu()
246                }
247                NodeType::Conv2d(inp, w, stride, padding) => {
248                    let va = values.get(inp).unwrap();
249                    let vw = values.get(w).unwrap();
250                    va.conv2d(vw, *stride, *padding)
251                }
252
253                // Fused Ops
254                NodeType::Conv2dRelu(inp, w, stride, padding) => {
255                    let va = values.get(inp).unwrap();
256                    let vw = values.get(w).unwrap();
257                    // In real XLA, this calls a fused kernel.
258                    // Here we emulate by calling conv then relu.
259                    // But we could dispatch to a specialized kernel if we had one.
260                    let conv = va.conv2d(vw, *stride, *padding);
261                    conv.relu()
262                }
263
264                _ => panic!("Op not implemented in executor"),
265            };
266
267            values.insert(node.id, val);
268        }
269
270        graph
271            .outputs
272            .iter()
273            .map(|id| values.get(id).unwrap().clone())
274            .collect()
275    }
276}