//! rustorch_core/jit.rs — minimal JIT IR: graph definition, fusion optimizer, and executor.
1use crate::Tensor;
2use std::collections::HashMap;
3
4// --- IR Definition ---
5
/// Operation kinds stored in the JIT graph IR.
///
/// Bare `usize` fields are operand *node ids*, i.e. indices into
/// `Graph::nodes`.
#[derive(Clone, Debug, PartialEq)]
pub enum NodeType {
    /// Graph input placeholder; payload is the position in `Graph::inputs`.
    Input(usize),
    /// Captured weight (constant) embedded directly in the graph.
    Weight(Tensor),

    // Ops — operands are node ids.
    Add(usize, usize), // LHS, RHS node indices
    Mul(usize, usize),
    MatMul(usize, usize),
    Relu(usize),
    Conv2d(usize, usize, (usize, usize), (usize, usize)), // Input, Weight, Stride, Padding

    // Fused Ops — produced by the optimizer, never by tracing.
    Conv2dRelu(usize, usize, (usize, usize), (usize, usize)), // Input, Weight, Stride, Padding
    LinearRelu(usize, usize, usize), // Input, Weight, Bias (Optional?)
}
22
/// A single operation in the computation graph.
#[derive(Debug)]
pub struct Node {
    /// The operation this node performs; operand fields are node ids.
    pub op: NodeType,
    /// Shape of this node's result tensor.
    pub shape: Vec<usize>,
    /// This node's id — equal to its index within `Graph::nodes`.
    pub id: usize,
    // dependencies, users, etc.
}
30
/// A recorded computation graph.
///
/// Nodes are kept in topological order by construction (a node's operands
/// are always appended before the node itself), so index order doubles as
/// execution order.
#[derive(Debug)]
pub struct Graph {
    /// All nodes; a node's `id` equals its index in this vector.
    pub nodes: Vec<Node>,
    /// Node ids of the `Input` nodes, in argument order.
    pub inputs: Vec<usize>,
    /// Node ids whose values the executor returns.
    pub outputs: Vec<usize>,
}
37
38impl Default for Graph {
39 fn default() -> Self {
40 Self::new()
41 }
42}
43
44impl Graph {
45 pub fn new() -> Self {
46 Self {
47 nodes: Vec::new(),
48 inputs: Vec::new(),
49 outputs: Vec::new(),
50 }
51 }
52
53 pub fn add_node(&mut self, op: NodeType, shape: Vec<usize>) -> usize {
54 let id = self.nodes.len();
55 self.nodes.push(Node { op, shape, id });
56 id
57 }
58
59 pub fn add_input(&mut self, shape: Vec<usize>) -> usize {
60 let id = self.add_node(NodeType::Input(self.inputs.len()), shape);
61 self.inputs.push(id);
62 id
63 }
64
65 pub fn add_weight(&mut self, tensor: Tensor) -> usize {
66 self.add_node(NodeType::Weight(tensor.clone()), tensor.shape().to_vec())
67 }
68}
69
// --- Tracer ---
// A simple tracer that records operations. In a real framework we would use
// a thread-local graph context or proxy tensors; here tracing is simulated by
// building the graph manually (or via a "TracedTensor" wrapper).

// Below: a simple optimizer pass over the recorded graph.
76
77pub struct Optimizer;
78
79impl Optimizer {
80 pub fn optimize(graph: &mut Graph) {
81 Self::fuse_conv_relu(graph);
82 // Self::eliminate_dead_code(graph);
83 }
84
85 fn fuse_conv_relu(graph: &mut Graph) {
86 // Look for Conv2d -> Relu pattern
87 // This requires analyzing graph topology.
88 // For simplicity: Iterate nodes, if Relu(Conv2d(idx)), replace op.
89
90 // We can't easily modify Vec while iterating.
91 // And we need to redirect edges.
92 // Simplified approach: Build new graph.
93
94 let mut new_nodes = Vec::new();
95 let mut mapping = HashMap::new(); // Old ID -> New ID
96
97 // We iterate old nodes.
98 // If we see Conv2d, we look ahead? No, usually we look at Relu and check input.
99
100 // But to rebuild, we visit in topological order (which is index order here).
101
102 let n = graph.nodes.len();
103 let mut consumed = vec![false; n];
104
105 for i in 0..n {
106 if consumed[i] {
107 continue;
108 }
109
110 let node = &graph.nodes[i];
111
112 match &node.op {
113 NodeType::Conv2d(input_id, weight_id, stride, padding) => {
114 // Check if this node is used ONLY by a Relu
115 // If so, we can fuse.
116 // We need use-def chains.
117 // For this demo, let's peek ahead.
118 // If next node is Relu and takes this Conv2d as input, fuse.
119 // (This assumes linear ordering which is not guaranteed but common in sequential models)
120
121 let mut fused = false;
122 // Find if any future node is Relu(i)
123 // Optimization: just check if next one is Relu(i)
124 if i + 1 < n {
125 if let NodeType::Relu(inp) = graph.nodes[i + 1].op {
126 if inp == i {
127 // Found Fusion!
128 let new_id = new_nodes.len();
129 mapping.insert(i + 1, new_id); // Relu maps to Fused
130 // Conv2d node maps to Fused?
131 // Actually the output of Relu is the output of Fused.
132 // The output of Conv2d is consumed.
133
134 // Remap inputs
135 let new_input = *mapping.get(input_id).unwrap_or(input_id);
136 let new_weight = *mapping.get(weight_id).unwrap_or(weight_id);
137
138 new_nodes.push(Node {
139 op: NodeType::Conv2dRelu(
140 new_input, new_weight, *stride, *padding,
141 ),
142 shape: graph.nodes[i + 1].shape.clone(),
143 id: new_id,
144 });
145
146 consumed[i + 1] = true; // Skip Relu
147 fused = true;
148 }
149 }
150 }
151
152 if !fused {
153 // Copy Conv2d
154 let new_id = new_nodes.len();
155 mapping.insert(i, new_id);
156 let new_input = *mapping.get(input_id).unwrap_or(input_id);
157 let new_weight = *mapping.get(weight_id).unwrap_or(weight_id);
158
159 new_nodes.push(Node {
160 op: NodeType::Conv2d(new_input, new_weight, *stride, *padding),
161 shape: node.shape.clone(),
162 id: new_id,
163 });
164 }
165 }
166
167 // Generic copy for others
168 op => {
169 let new_id = new_nodes.len();
170 mapping.insert(i, new_id);
171
172 // Remap inputs
173 let new_op = match op {
174 NodeType::Add(a, b) => NodeType::Add(
175 *mapping.get(a).unwrap_or(a),
176 *mapping.get(b).unwrap_or(b),
177 ),
178 NodeType::Mul(a, b) => NodeType::Mul(
179 *mapping.get(a).unwrap_or(a),
180 *mapping.get(b).unwrap_or(b),
181 ),
182 NodeType::Relu(a) => NodeType::Relu(*mapping.get(a).unwrap_or(a)),
183 // ... copy others
184 _ => op.clone(),
185 };
186
187 new_nodes.push(Node {
188 op: new_op,
189 shape: node.shape.clone(),
190 id: new_id,
191 });
192 }
193 }
194 }
195
196 graph.nodes = new_nodes;
197 // Remap outputs
198 for out in &mut graph.outputs {
199 if let Some(&new_id) = mapping.get(out) {
200 *out = new_id;
201 }
202 }
203 // Remap inputs (Node IDs)
204 for inp in &mut graph.inputs {
205 if let Some(&new_id) = mapping.get(inp) {
206 *inp = new_id;
207 }
208 }
209 }
210}
211
212// --- Executor ---
213pub struct Executor;
214
215impl Executor {
216 pub fn run(graph: &Graph, inputs: &[Tensor]) -> Vec<Tensor> {
217 let mut values: HashMap<usize, Tensor> = HashMap::new();
218
219 // Load inputs
220 for (i, &id) in graph.inputs.iter().enumerate() {
221 values.insert(id, inputs[i].clone());
222 }
223
224 for node in &graph.nodes {
225 if values.contains_key(&node.id) {
226 continue;
227 } // Already computed (Input/Weight)
228
229 let val = match &node.op {
230 NodeType::Input(_) => panic!("Input should be loaded"),
231 NodeType::Weight(t) => t.clone(),
232
233 NodeType::Add(a, b) => {
234 let va = values.get(a).unwrap();
235 let vb = values.get(b).unwrap();
236 va.add(vb)
237 }
238 NodeType::Mul(a, b) => {
239 let va = values.get(a).unwrap();
240 let vb = values.get(b).unwrap();
241 va.mul(vb)
242 }
243 NodeType::Relu(a) => {
244 let va = values.get(a).unwrap();
245 va.relu()
246 }
247 NodeType::Conv2d(inp, w, stride, padding) => {
248 let va = values.get(inp).unwrap();
249 let vw = values.get(w).unwrap();
250 va.conv2d(vw, *stride, *padding)
251 }
252
253 // Fused Ops
254 NodeType::Conv2dRelu(inp, w, stride, padding) => {
255 let va = values.get(inp).unwrap();
256 let vw = values.get(w).unwrap();
257 // In real XLA, this calls a fused kernel.
258 // Here we emulate by calling conv then relu.
259 // But we could dispatch to a specialized kernel if we had one.
260 let conv = va.conv2d(vw, *stride, *padding);
261 conv.relu()
262 }
263
264 _ => panic!("Op not implemented in executor"),
265 };
266
267 values.insert(node.id, val);
268 }
269
270 graph
271 .outputs
272 .iter()
273 .map(|id| values.get(id).unwrap().clone())
274 .collect()
275 }
276}