use crate::error::{Error, Result};
use crate::tracing::types::{OpAttr, OpNode, OpType, StaticGraph};
use std::collections::{HashMap, HashSet, VecDeque};

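/// A CPU interpreter for a traced [`StaticGraph`]: tensors are flat `f64`
/// buffers, and parameters are looked up in `weight_map` by string key
/// (e.g. `linear_{id}_weight`).
///
/// A minimal usage sketch, mirroring the tests below (`weights` stands in
/// for any `HashMap<String, Vec<f64>>` keyed by that scheme):
///
/// ```ignore
/// let mut builder = GraphBuilder::new();
/// let input = builder.input(TensorSpec::new(vec![1, 4], DType::F64));
/// let hidden = builder.linear(input, 4, 4);
/// let out = builder.relu(hidden);
/// let graph = builder.build(vec![out]);
///
/// let executor = GraphExecutor::new(graph, weights);
/// let result = executor.run(&[vec![1.0, 2.0, 3.0, 4.0]])?;
/// ```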
pub struct GraphExecutor {
    graph: StaticGraph,
    weight_map: HashMap<String, Vec<f64>>,
}

impl GraphExecutor {
    pub fn new(graph: StaticGraph, weight_map: HashMap<String, Vec<f64>>) -> Self {
        Self { graph, weight_map }
    }

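    /// Executes the graph on the given input tensors (matched positionally
    /// to `input_node_ids`) and returns one tensor per graph output. Nodes
    /// are assumed to be stored in topological order, so every input is
    /// already in the cache by the time its consumer runs.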
    pub fn run(&self, inputs: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
        if inputs.len() != self.graph.input_node_ids.len() {
            return Err(Error::InvalidArgument(format!(
                "Expected {} inputs, got {}",
                self.graph.input_node_ids.len(),
                inputs.len()
            )));
        }

        let mut tensor_cache: HashMap<usize, Vec<f64>> = HashMap::new();

        for (inp_tensor, &node_id) in inputs.iter().zip(self.graph.input_node_ids.iter()) {
            tensor_cache.insert(node_id, inp_tensor.clone());
        }

        for node in &self.graph.nodes {
            if node.op_type == OpType::Constant {
                // Constant nodes have no runtime inputs; materialize a
                // zero-filled buffer of the declared size as a placeholder.
                tensor_cache.entry(node.id).or_insert_with(|| {
                    let n = node.output_spec.num_elements();
                    vec![0.0_f64; n]
                });
                continue;
            }

            let output = self.execute_node(node, &tensor_cache)?;
            tensor_cache.insert(node.id, output);
        }

        let mut results = Vec::with_capacity(self.graph.output_node_ids.len());
        for &out_id in &self.graph.output_node_ids {
            let tensor = tensor_cache
                .get(&out_id)
                .ok_or_else(|| {
                    Error::InvalidArgument(format!("Output node {} not computed", out_id))
                })?
                .clone();
            results.push(tensor);
        }
        Ok(results)
    }

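    /// Dispatches a single node to its op-specific kernel. Inputs are read
    /// from `cache`, which holds every tensor produced so far.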
    fn execute_node(&self, node: &OpNode, cache: &HashMap<usize, Vec<f64>>) -> Result<Vec<f64>> {
        match &node.op_type {
            OpType::Linear => self.exec_linear(node, cache),
            OpType::ReLU => self.exec_elementwise(node, cache, |x| x.max(0.0)),
            OpType::Sigmoid => self.exec_elementwise(node, cache, |x| 1.0 / (1.0 + (-x).exp())),
            OpType::Tanh => self.exec_elementwise(node, cache, |x| x.tanh()),
            OpType::Add => self.exec_binary(node, cache, |a, b| a + b),
            OpType::Mul => self.exec_binary(node, cache, |a, b| a * b),
            OpType::Reshape => self.exec_reshape(node, cache),
            OpType::Softmax => self.exec_softmax(node, cache),
            OpType::LayerNorm => self.exec_layer_norm(node, cache),
            OpType::FusedLinearReLU => self.exec_fused_linear_relu(node, cache),
            // Transpose reuses the reshape path. Since tensors are flat
            // buffers, this is only correct when no data movement is needed
            // (e.g. a [1, N] <-> [N, 1] swap).
            OpType::Transpose => self.exec_reshape(node, cache),
            // BatchNorm is currently a passthrough: the input is returned
            // unchanged.
            OpType::BatchNorm => {
                let inp_id = node
                    .inputs
                    .first()
                    .ok_or_else(|| Error::InvalidArgument("BatchNorm has no inputs".to_string()))?;
                Ok(cache
                    .get(inp_id)
                    .ok_or_else(|| Error::InvalidArgument(format!("Input {} not found", inp_id)))?
                    .clone())
            }
            OpType::Conv1d => Err(Error::NotImplemented(
                "Conv1d execution not yet implemented".to_string(),
            )),
            _ => Err(Error::NotImplemented(format!(
                "OpType {:?} not implemented in executor",
                node.op_type
            ))),
        }
    }

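    /// Looks up the `idx`-th input tensor of `node` in the cache, with
    /// descriptive errors for a missing edge or an uncomputed producer.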
    fn get_input<'a>(
        &self,
        node: &OpNode,
        idx: usize,
        cache: &'a HashMap<usize, Vec<f64>>,
    ) -> Result<&'a Vec<f64>> {
        let node_id = node.inputs.get(idx).ok_or_else(|| {
            Error::InvalidArgument(format!("Node {} has no input at index {}", node.id, idx))
        })?;
        cache.get(node_id).ok_or_else(|| {
            Error::InvalidArgument(format!("Input tensor for node {} not in cache", node_id))
        })
    }

    fn exec_elementwise(
        &self,
        node: &OpNode,
        cache: &HashMap<usize, Vec<f64>>,
        f: impl Fn(f64) -> f64,
    ) -> Result<Vec<f64>> {
        let input = self.get_input(node, 0, cache)?;
        Ok(input.iter().map(|&x| f(x)).collect())
    }

    fn exec_binary(
        &self,
        node: &OpNode,
        cache: &HashMap<usize, Vec<f64>>,
        f: impl Fn(f64, f64) -> f64,
    ) -> Result<Vec<f64>> {
        let a = self.get_input(node, 0, cache)?;
        let b = self.get_input(node, 1, cache)?;
        if a.len() != b.len() {
            return Err(Error::InvalidArgument(format!(
                "Binary op shape mismatch: {} vs {}",
                a.len(),
                b.len()
            )));
        }
        Ok(a.iter().zip(b.iter()).map(|(&av, &bv)| f(av, bv)).collect())
    }

    fn exec_reshape(&self, node: &OpNode, cache: &HashMap<usize, Vec<f64>>) -> Result<Vec<f64>> {
        // Tensors are flat buffers, so a reshape is just a copy; the new
        // shape lives in the node's output_spec.
        let input = self.get_input(node, 0, cache)?;
        Ok(input.clone())
    }

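    /// Computes `y = W x + b` for each row of the (flattened) batch.
    /// Parameters are fetched from `weight_map` under `linear_{id}_weight`
    /// and `linear_{id}_bias`, with `W` stored row-major as
    /// `[out_features, in_features]`.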
    fn exec_linear(&self, node: &OpNode, cache: &HashMap<usize, Vec<f64>>) -> Result<Vec<f64>> {
        let input = self.get_input(node, 0, cache)?;

        let out_feat = get_attr_int(&node.attrs, "out_features")? as usize;
        let in_feat = get_attr_int(&node.attrs, "in_features")? as usize;

        if in_feat == 0 {
            return Err(Error::InvalidArgument("in_features cannot be 0".into()));
        }
        let batch = input.len() / in_feat;
        if batch * in_feat != input.len() {
            return Err(Error::InvalidArgument(format!(
                "Input length {} not divisible by in_features {}",
                input.len(),
                in_feat
            )));
        }

        let weight_key = format!("linear_{}_weight", node.id);
        let bias_key = format!("linear_{}_bias", node.id);

        let weight = self
            .weight_map
            .get(&weight_key)
            .ok_or_else(|| Error::InvalidArgument(format!("Missing weight '{}'", weight_key)))?;
        let bias = self
            .weight_map
            .get(&bias_key)
            .ok_or_else(|| Error::InvalidArgument(format!("Missing bias '{}'", bias_key)))?;

        if weight.len() != out_feat * in_feat {
            return Err(Error::InvalidArgument(format!(
                "Weight shape mismatch: expected {}×{}, got {}",
                out_feat,
                in_feat,
                weight.len()
            )));
        }

        let mut output = vec![0.0_f64; batch * out_feat];
        for b in 0..batch {
            for o in 0..out_feat {
                let mut acc = bias.get(o).copied().unwrap_or(0.0);
                for i in 0..in_feat {
                    acc += weight[o * in_feat + i] * input[b * in_feat + i];
                }
                output[b * out_feat + o] = acc;
            }
        }
        Ok(output)
    }

    fn exec_fused_linear_relu(
        &self,
        node: &OpNode,
        cache: &HashMap<usize, Vec<f64>>,
    ) -> Result<Vec<f64>> {
        let linear_out = self.exec_linear(node, cache)?;
        Ok(linear_out.iter().map(|&x| x.max(0.0)).collect())
    }

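    /// Numerically stable softmax: the row maximum is subtracted before
    /// exponentiation so large logits cannot overflow. Rows are taken as
    /// contiguous chunks of `shape[dim]`, which is only correct when `dim`
    /// is the last dimension.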
    fn exec_softmax(&self, node: &OpNode, cache: &HashMap<usize, Vec<f64>>) -> Result<Vec<f64>> {
        let input = self.get_input(node, 0, cache)?;
        let dim = get_attr_int(&node.attrs, "dim").unwrap_or(1) as usize;

        let shape = &node.output_spec.shape;
        if shape.is_empty() {
            return Ok(input.clone());
        }

        let row_size = if dim < shape.len() {
            shape[dim]
        } else {
            input.len()
        };
        if row_size == 0 {
            return Ok(input.clone());
        }

        let n_rows = input.len() / row_size;
        let mut output = vec![0.0_f64; input.len()];

        for r in 0..n_rows {
            let row = &input[r * row_size..(r + 1) * row_size];
            let max_val = row.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
            let exp_vals: Vec<f64> = row.iter().map(|&v| (v - max_val).exp()).collect();
            let sum: f64 = exp_vals.iter().sum();
            let sum_safe = if sum > 0.0 { sum } else { 1.0 };
            for (i, &e) in exp_vals.iter().enumerate() {
                output[r * row_size + i] = e / sum_safe;
            }
        }
        Ok(output)
    }

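    /// Normalizes each row over the last dimension to zero mean and unit
    /// variance, then applies the optional learned scale (`gamma`) and
    /// shift (`beta`); missing parameters default to 1 and 0.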
    fn exec_layer_norm(&self, node: &OpNode, cache: &HashMap<usize, Vec<f64>>) -> Result<Vec<f64>> {
        let input = self.get_input(node, 0, cache)?;
        let eps = node
            .attrs
            .get("eps")
            .and_then(|a| match a {
                OpAttr::Float(f) => Some(*f),
                _ => None,
            })
            .unwrap_or(1e-5);

        let shape = &node.output_spec.shape;
        let last_dim = shape.last().copied().unwrap_or(input.len());
        if last_dim == 0 {
            return Ok(input.clone());
        }

        let n_rows = input.len() / last_dim;

        let gamma_key = format!("layer_norm_{}_gamma", node.id);
        let beta_key = format!("layer_norm_{}_beta", node.id);
        let gamma = self.weight_map.get(&gamma_key);
        let beta = self.weight_map.get(&beta_key);

        let mut output = vec![0.0_f64; input.len()];
        for r in 0..n_rows {
            let row = &input[r * last_dim..(r + 1) * last_dim];
            let mean = row.iter().sum::<f64>() / last_dim as f64;
            let var = row.iter().map(|&v| (v - mean).powi(2)).sum::<f64>() / last_dim as f64;
            let std_inv = 1.0 / (var + eps).sqrt();
            for (i, &v) in row.iter().enumerate() {
                let normalized = (v - mean) * std_inv;
                let scaled = gamma.and_then(|g| g.get(i).copied()).unwrap_or(1.0) * normalized;
                let shifted = scaled + beta.and_then(|b| b.get(i).copied()).unwrap_or(0.0);
                output[r * last_dim + i] = shifted;
            }
        }
        Ok(output)
    }
}

fn get_attr_int(attrs: &HashMap<String, OpAttr>, key: &str) -> Result<i64> {
    match attrs.get(key) {
        Some(OpAttr::Int(v)) => Ok(*v),
        Some(_) => Err(Error::InvalidArgument(format!(
            "Attribute '{}' is not an integer",
            key
        ))),
        None => Err(Error::InvalidArgument(format!(
            "Missing attribute '{}'",
            key
        ))),
    }
}

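/// Optimizes a traced graph: dead-node elimination followed by
/// Linear+ReLU operator fusion. Constant folding is currently a stub
/// (see `_constant_folding`) and is not part of the pipeline yet.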
pub fn optimize(graph: &StaticGraph) -> StaticGraph {
    let after_dne = dead_node_elimination(graph);
    operator_fusion(&after_dne)
}

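/// Removes nodes that cannot reach any graph output: a breadth-first walk
/// starts from the output nodes and follows input edges backwards; any node
/// never visited is dead and is dropped.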
fn dead_node_elimination(graph: &StaticGraph) -> StaticGraph {
    let mut live: HashSet<usize> = HashSet::new();
    let mut queue: VecDeque<usize> = VecDeque::new();

    for &out_id in &graph.output_node_ids {
        if live.insert(out_id) {
            queue.push_back(out_id);
        }
    }

    // Map each node id to the ids of the nodes that produce its inputs.
    let mut producers: HashMap<usize, Vec<usize>> = HashMap::new();
    for node in &graph.nodes {
        for &inp_id in &node.inputs {
            producers.entry(node.id).or_default().push(inp_id);
        }
    }

    // Walk backwards from the outputs, marking every reachable node as live.
    while let Some(id) = queue.pop_front() {
        for &prod_id in producers.get(&id).unwrap_or(&vec![]) {
            if live.insert(prod_id) {
                queue.push_back(prod_id);
            }
        }
    }

    let kept_nodes: Vec<OpNode> = graph
        .nodes
        .iter()
        .filter(|n| live.contains(&n.id))
        .cloned()
        .collect();

    let mut id_to_idx = HashMap::new();
    for (idx, node) in kept_nodes.iter().enumerate() {
        id_to_idx.insert(node.id, idx);
    }

    let mut new_graph = StaticGraph::new(graph.inputs.clone(), graph.outputs.clone());
    new_graph.nodes = kept_nodes;
    new_graph.id_to_idx = id_to_idx;
    new_graph.input_node_ids = graph.input_node_ids.clone();
    new_graph.output_node_ids = graph.output_node_ids.clone();
    new_graph
}

/// Placeholder: constant folding is not implemented yet, so the graph is
/// returned unchanged.
fn _constant_folding(graph: &StaticGraph) -> StaticGraph {
    graph.clone()
}

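/// Fuses `Linear -> ReLU` pairs into a single `FusedLinearReLU` node. A pair
/// is fused only when the ReLU is the sole consumer of the Linear output, so
/// no other node observes the pre-activation values.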
fn operator_fusion(graph: &StaticGraph) -> StaticGraph {
    let mut fused_nodes = graph.nodes.clone();

    // Collect (linear_idx, relu_idx) pairs that are safe to fuse.
    let mut to_fuse: Vec<(usize, usize)> = Vec::new();
    for (relu_idx, node) in fused_nodes.iter().enumerate() {
        if node.op_type != OpType::ReLU {
            continue;
        }
        let relu_input_id = match node.inputs.first() {
            Some(&id) => id,
            None => continue,
        };
        let linear_idx = match fused_nodes
            .iter()
            .position(|n| n.id == relu_input_id && n.op_type == OpType::Linear)
        {
            Some(i) => i,
            None => continue,
        };
        // Fuse only if the ReLU is the sole consumer of the Linear output.
        let linear_consumer_count = fused_nodes
            .iter()
            .filter(|n| n.inputs.contains(&relu_input_id))
            .count();
        if linear_consumer_count == 1 {
            to_fuse.push((linear_idx, relu_idx));
        }
    }

    let mut remove_ids: HashSet<usize> = HashSet::new();
    let mut relu_id_to_linear_id: HashMap<usize, usize> = HashMap::new();

    for (linear_idx, relu_idx) in to_fuse {
        let relu_id = fused_nodes[relu_idx].id;
        let linear_id = fused_nodes[linear_idx].id;

        fused_nodes[linear_idx].op_type = OpType::FusedLinearReLU;
        remove_ids.insert(relu_id);
        relu_id_to_linear_id.insert(relu_id, linear_id);
    }

    // Rewire consumers of each removed ReLU to read from the fused node.
    for node in &mut fused_nodes {
        for inp_id in &mut node.inputs {
            if let Some(&fused_id) = relu_id_to_linear_id.get(inp_id) {
                *inp_id = fused_id;
            }
        }
    }

    fused_nodes.retain(|n| !remove_ids.contains(&n.id));

    let mut id_to_idx = HashMap::new();
    for (idx, node) in fused_nodes.iter().enumerate() {
        id_to_idx.insert(node.id, idx);
    }

    // Graph outputs that pointed at a removed ReLU now point at the fused node.
    let output_node_ids: Vec<usize> = graph
        .output_node_ids
        .iter()
        .map(|&id| *relu_id_to_linear_id.get(&id).unwrap_or(&id))
        .collect();

    let mut new_graph = StaticGraph::new(graph.inputs.clone(), graph.outputs.clone());
    new_graph.nodes = fused_nodes;
    new_graph.id_to_idx = id_to_idx;
    new_graph.input_node_ids = graph.input_node_ids.clone();
    new_graph.output_node_ids = output_node_ids;
    new_graph
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tracing::graph_builder::GraphBuilder;
    use crate::tracing::types::{DType, TensorSpec};

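    /// Builds an identity-like weight map for a Linear node: ones on the
    /// leading square diagonal, zero bias.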
    fn linear_weights(node_id: usize, in_f: usize, out_f: usize) -> HashMap<String, Vec<f64>> {
        let mut map = HashMap::new();
        let mut weight = vec![0.0_f64; out_f * in_f];
        for o in 0..out_f.min(in_f) {
            weight[o * in_f + o] = 1.0;
        }
        map.insert(format!("linear_{}_weight", node_id), weight);
        map.insert(format!("linear_{}_bias", node_id), vec![0.0_f64; out_f]);
        map
    }

    #[test]
    fn test_executor_linear_relu() {
        let mut builder = GraphBuilder::new();
        let input = builder.input(TensorSpec::new(vec![1, 4], DType::F64));
        let h = builder.linear(input, 4, 4);
        let out = builder.relu(h);
        let graph = builder.build(vec![out]);

        let linear_id = graph
            .nodes
            .iter()
            .find(|n| n.op_type == OpType::Linear)
            .map(|n| n.id)
            .expect("test: linear node");

        let mut weights = linear_weights(linear_id, 4, 4);
        weights.insert(
            format!("linear_{}_bias", linear_id),
            vec![-1.0, -1.0, 1.0, 1.0],
        );

        let executor = GraphExecutor::new(graph, weights);
        let result = executor
            .run(&[vec![1.0, 2.0, 3.0, 4.0]])
            .expect("test: run");
        assert_eq!(result.len(), 1);
        let out = &result[0];
        assert_eq!(out.len(), 4);
        for &v in out {
            assert!(v >= 0.0, "ReLU output must be >= 0, got {v}");
        }
    }

    #[test]
    fn test_executor_softmax_sums_one() {
        let mut builder = GraphBuilder::new();
        let input = builder.input(TensorSpec::new(vec![1, 5], DType::F64));
        let out = builder.softmax(input, 1);
        let graph = builder.build(vec![out]);

        let executor = GraphExecutor::new(graph, HashMap::new());
        let result = executor
            .run(&[vec![1.0, 2.0, 3.0, 4.0, 5.0]])
            .expect("test: run softmax");
        let out = &result[0];
        assert_eq!(out.len(), 5);
        let sum: f64 = out.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-9,
            "Softmax should sum to 1, got {sum}"
        );
    }

    #[test]
    fn test_executor_layer_norm() {
        let mut builder = GraphBuilder::new();
        let input = builder.input(TensorSpec::new(vec![1, 8], DType::F64));
        let out = builder.layer_norm(input, 1e-5);
        let graph = builder.build(vec![out]);

        let executor = GraphExecutor::new(graph, HashMap::new());
        let data: Vec<f64> = (0..8).map(|i| i as f64).collect();
        let result = executor.run(&[data]).expect("test: run layer_norm");
        let out = &result[0];
        assert_eq!(out.len(), 8);
        let mean: f64 = out.iter().sum::<f64>() / out.len() as f64;
        let var: f64 = out.iter().map(|&v| (v - mean).powi(2)).sum::<f64>() / out.len() as f64;
        assert!(mean.abs() < 1e-6, "LayerNorm mean should be ~0, got {mean}");
        assert!(
            (var - 1.0).abs() < 1e-4,
            "LayerNorm variance should be ~1, got {var}"
        );
    }

    #[test]
    fn test_graph_dead_node_elimination() {
        let mut builder = GraphBuilder::new();
        let input = builder.input(TensorSpec::new(vec![1, 4], DType::F64));
        let h1 = builder.linear(input, 4, 4);
        // This ReLU feeds no graph output, so it should be eliminated.
        let _dead = builder.relu(input);
        let graph = builder.build(vec![h1]);

        let before_count = graph.num_nodes();
        let optimized = dead_node_elimination(&graph);
        assert!(
            optimized.num_nodes() < before_count,
            "Dead node elimination should reduce node count: before={before_count}, after={}",
            optimized.num_nodes()
        );
    }

    #[test]
    fn test_optimize_pipeline() {
        // Constant folding is currently a stub, so this exercises the full
        // optimize() pipeline and checks it preserves a non-empty graph.
        let mut builder = GraphBuilder::new();
        let input = builder.input(TensorSpec::new(vec![1, 4], DType::F64));
        let out = builder.linear(input, 4, 2);
        let graph = builder.build(vec![out]);

        let optimized = optimize(&graph);
        assert!(optimized.num_nodes() > 0);
    }

    #[test]
    fn test_operator_fusion() {
        let mut builder = GraphBuilder::new();
        let input = builder.input(TensorSpec::new(vec![1, 4], DType::F64));
        let linear_out = builder.linear(input, 4, 4);
        let relu_out = builder.relu(linear_out);
        let graph = builder.build(vec![relu_out]);

        let before_count = graph.num_nodes();
        let fused = operator_fusion(&graph);
        assert!(
            fused.num_nodes() < before_count,
            "Fusion should reduce node count: before={before_count}, after={}",
            fused.num_nodes()
        );
        let has_fused = fused
            .nodes
            .iter()
            .any(|n| n.op_type == OpType::FusedLinearReLU);
        assert!(
            has_fused,
            "Graph should contain FusedLinearReLU after fusion"
        );
    }

    #[test]
    fn test_static_graph_shapes_consistent() {
        let mut builder = GraphBuilder::new();
        let input = builder.input(TensorSpec::new(vec![1, 16], DType::F64));
        let h1 = builder.linear(input, 16, 8);
        let h2 = builder.relu(h1);
        let out = builder.linear(h2, 8, 4);
        let graph = builder.build(vec![out]);

        let linear_out_shapes: Vec<Vec<usize>> = graph
            .nodes
            .iter()
            .filter(|n| n.op_type == OpType::Linear)
            .map(|n| n.output_spec.shape.clone())
            .collect();

        assert_eq!(linear_out_shapes.len(), 2);
        assert_eq!(linear_out_shapes[0], vec![1, 8]);
        assert_eq!(linear_out_shapes[1], vec![1, 4]);
    }
}