scirs2_neural/tracing/executor.rs

//! Graph executor for static computation graphs.
//!
//! Executes a `StaticGraph` in topological order, dispatching each operation
//! to a pure-Rust implementation. Weights are supplied via a `weight_map`.
//!
//! Also provides `optimize()`, which runs two passes:
//! - **Dead node elimination** — remove nodes not reachable from the graph outputs
//! - **Operator fusion** — combine consecutive Linear→ReLU pairs into `FusedLinearReLU`
//!
//! (A constant-folding hook exists but is currently a pass-through; see
//! `_constant_folding`.)
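//!
//! # Example
//!
//! A minimal usage sketch, mirroring the tests at the bottom of this file
//! (weight keys follow the conventions documented below):
//!
//! ```ignore
//! let mut builder = GraphBuilder::new();
//! let input = builder.input(TensorSpec::new(vec![1, 4], DType::F64));
//! let h = builder.linear(input, 4, 4);
//! let out = builder.relu(h);
//! let graph = builder.build(vec![out]);
//!
//! // `weights` maps names like "linear_{id}_weight" to flat f64 buffers
//! let executor = GraphExecutor::new(graph, weights);
//! let outputs = executor.run(&[vec![1.0, 2.0, 3.0, 4.0]])?;
//! ```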

use crate::error::{Error, Result};
use crate::tracing::types::{OpAttr, OpNode, OpType, StaticGraph};
use std::collections::{HashMap, HashSet, VecDeque};

// ---------------------------------------------------------------------------
// Weight store key conventions
// ---------------------------------------------------------------------------
//
// Weights for a `Linear` node with id N are stored as:
//   "linear_{N}_weight"  shape: [out_features, in_features]
//   "linear_{N}_bias"    shape: [out_features]
//
// For LayerNorm with id N:
//   "layer_norm_{N}_gamma"  shape: [features]
//   "layer_norm_{N}_beta"   shape: [features]
// ---------------------------------------------------------------------------
// GraphExecutor
// ---------------------------------------------------------------------------

/// Holds a `StaticGraph` and a weight map, and can execute the graph.
pub struct GraphExecutor {
    graph: StaticGraph,
    /// Named float tensors (weights, biases, scales, etc.)
    weight_map: HashMap<String, Vec<f64>>,
}

impl GraphExecutor {
    /// Create an executor from a graph and its associated weights.
    pub fn new(graph: StaticGraph, weight_map: HashMap<String, Vec<f64>>) -> Self {
        Self { graph, weight_map }
    }

    /// Run the graph with the given input tensors (flat f64 slices, one per graph input).
    ///
    /// Returns the output tensors in the same order as `graph.output_node_ids`.
    pub fn run(&self, inputs: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
        // Validate input count
        if inputs.len() != self.graph.input_node_ids.len() {
            return Err(Error::InvalidArgument(format!(
                "Expected {} inputs, got {}",
                self.graph.input_node_ids.len(),
                inputs.len()
            )));
        }

        // Map: node_id → computed flat tensor
        let mut tensor_cache: HashMap<usize, Vec<f64>> = HashMap::new();

        // Seed with graph inputs
        for (inp_tensor, &node_id) in inputs.iter().zip(self.graph.input_node_ids.iter()) {
            tensor_cache.insert(node_id, inp_tensor.clone());
        }

        // Execute nodes in order (already topologically sorted by GraphBuilder)
        for node in &self.graph.nodes {
            // Placeholder input nodes were seeded above; don't recompute them
            if tensor_cache.contains_key(&node.id) {
                continue;
            }
            // Constant with no inputs: default to zeros
            if node.op_type == OpType::Constant {
                let n = node.output_spec.num_elements();
                tensor_cache.insert(node.id, vec![0.0_f64; n]);
                continue;
            }

            let output = self.execute_node(node, &tensor_cache)?;
            tensor_cache.insert(node.id, output);
        }

        // Collect outputs
        let mut results = Vec::with_capacity(self.graph.output_node_ids.len());
        for &out_id in &self.graph.output_node_ids {
            let tensor = tensor_cache
                .get(&out_id)
                .ok_or_else(|| {
                    Error::InvalidArgument(format!("Output node {} not computed", out_id))
                })?
                .clone();
            results.push(tensor);
        }
        Ok(results)
    }

    // -----------------------------------------------------------------------
    // Per-operation dispatch
    // -----------------------------------------------------------------------

    fn execute_node(&self, node: &OpNode, cache: &HashMap<usize, Vec<f64>>) -> Result<Vec<f64>> {
        match &node.op_type {
            OpType::Linear => self.exec_linear(node, cache),
            OpType::ReLU => self.exec_elementwise(node, cache, |x| x.max(0.0)),
            OpType::Sigmoid => self.exec_elementwise(node, cache, |x| 1.0 / (1.0 + (-x).exp())),
            OpType::Tanh => self.exec_elementwise(node, cache, |x| x.tanh()),
            OpType::Add => self.exec_binary(node, cache, |a, b| a + b),
            OpType::Mul => self.exec_binary(node, cache, |a, b| a * b),
            OpType::Reshape => self.exec_reshape(node, cache),
            OpType::Softmax => self.exec_softmax(node, cache),
            OpType::LayerNorm => self.exec_layer_norm(node, cache),
            OpType::FusedLinearReLU => self.exec_fused_linear_relu(node, cache),
            // Simplified: the flat buffer is returned unchanged, which is only
            // correct where consumers don't depend on the memory layout
            OpType::Transpose => self.exec_reshape(node, cache),
            OpType::BatchNorm => {
                // Simplified: treat as identity for executor tests
                let inp_id = node
                    .inputs
                    .first()
                    .ok_or_else(|| Error::InvalidArgument("BatchNorm has no inputs".to_string()))?;
                Ok(cache
                    .get(inp_id)
                    .ok_or_else(|| Error::InvalidArgument(format!("Input {} not found", inp_id)))?
                    .clone())
            }
            OpType::Conv1d => Err(Error::NotImplemented(
                "Conv1d execution not yet implemented".to_string(),
            )),
            _ => Err(Error::NotImplemented(format!(
                "OpType {:?} not implemented in executor",
                node.op_type
            ))),
        }
    }

    fn get_input<'a>(
        &self,
        node: &OpNode,
        idx: usize,
        cache: &'a HashMap<usize, Vec<f64>>,
    ) -> Result<&'a Vec<f64>> {
        let node_id = node.inputs.get(idx).ok_or_else(|| {
            Error::InvalidArgument(format!("Node {} has no input at index {}", node.id, idx))
        })?;
        cache.get(node_id).ok_or_else(|| {
            Error::InvalidArgument(format!("Input tensor for node {} not in cache", node_id))
        })
    }

    fn exec_elementwise(
        &self,
        node: &OpNode,
        cache: &HashMap<usize, Vec<f64>>,
        f: impl Fn(f64) -> f64,
    ) -> Result<Vec<f64>> {
        let input = self.get_input(node, 0, cache)?;
        Ok(input.iter().map(|&x| f(x)).collect())
    }

    fn exec_binary(
        &self,
        node: &OpNode,
        cache: &HashMap<usize, Vec<f64>>,
        f: impl Fn(f64, f64) -> f64,
    ) -> Result<Vec<f64>> {
        let a = self.get_input(node, 0, cache)?;
        let b = self.get_input(node, 1, cache)?;
        if a.len() != b.len() {
            return Err(Error::InvalidArgument(format!(
                "Binary op shape mismatch: {} vs {}",
                a.len(),
                b.len()
            )));
        }
        Ok(a.iter().zip(b.iter()).map(|(&av, &bv)| f(av, bv)).collect())
    }

    fn exec_reshape(&self, node: &OpNode, cache: &HashMap<usize, Vec<f64>>) -> Result<Vec<f64>> {
        let input = self.get_input(node, 0, cache)?;
        // Reshape is a no-op on the flat data; just return a clone
        Ok(input.clone())
    }

    fn exec_linear(&self, node: &OpNode, cache: &HashMap<usize, Vec<f64>>) -> Result<Vec<f64>> {
        let input = self.get_input(node, 0, cache)?;

        let out_feat = get_attr_int(&node.attrs, "out_features")? as usize;
        let in_feat = get_attr_int(&node.attrs, "in_features")? as usize;

        // Infer batch size
        if in_feat == 0 {
            return Err(Error::InvalidArgument("in_features cannot be 0".into()));
        }
        let batch = input.len() / in_feat;
        if batch * in_feat != input.len() {
            return Err(Error::InvalidArgument(format!(
                "Input length {} not divisible by in_features {}",
                input.len(),
                in_feat
            )));
        }

        // Fetch weight and bias
        let weight_key = format!("linear_{}_weight", node.id);
        let bias_key = format!("linear_{}_bias", node.id);

        let weight = self
            .weight_map
            .get(&weight_key)
            .ok_or_else(|| Error::InvalidArgument(format!("Missing weight '{}'", weight_key)))?;
        let bias = self
            .weight_map
            .get(&bias_key)
            .ok_or_else(|| Error::InvalidArgument(format!("Missing bias '{}'", bias_key)))?;

        if weight.len() != out_feat * in_feat {
            return Err(Error::InvalidArgument(format!(
                "Weight shape mismatch: expected {}×{}, got {}",
                out_feat,
                in_feat,
                weight.len()
            )));
        }

        // y[b, o] = Σ_i W[o, i] * x[b, i] + bias[o]
        let mut output = vec![0.0_f64; batch * out_feat];
        for b in 0..batch {
            for o in 0..out_feat {
                let mut acc = bias.get(o).copied().unwrap_or(0.0);
                for i in 0..in_feat {
                    acc += weight[o * in_feat + i] * input[b * in_feat + i];
                }
                output[b * out_feat + o] = acc;
            }
        }
        Ok(output)
    }

    fn exec_fused_linear_relu(
        &self,
        node: &OpNode,
        cache: &HashMap<usize, Vec<f64>>,
    ) -> Result<Vec<f64>> {
        let linear_out = self.exec_linear(node, cache)?;
        Ok(linear_out.iter().map(|&x| x.max(0.0)).collect())
    }

    fn exec_softmax(&self, node: &OpNode, cache: &HashMap<usize, Vec<f64>>) -> Result<Vec<f64>> {
        let input = self.get_input(node, 0, cache)?;
        let dim = get_attr_int(&node.attrs, "dim").unwrap_or(1) as usize;

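        // softmax(x)_i = exp(x_i - m) / Σ_j exp(x_j - m), with m = max_j x_j;
        // subtracting the max keeps exp() from overflowing
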
        // Determine row size along the softmax dimension
        let shape = &node.output_spec.shape;
        if shape.is_empty() {
            return Ok(input.clone());
        }

        // Compute softmax over contiguous rows of length shape[dim].
        // Simplification: this matches the usual semantics only when `dim`
        // is the innermost (last) dimension, the common case here.
        let row_size = if dim < shape.len() {
            shape[dim]
        } else {
            input.len()
        };
        if row_size == 0 {
            return Ok(input.clone());
        }

        let n_rows = input.len() / row_size;
        let mut output = vec![0.0_f64; input.len()];

        for r in 0..n_rows {
            let row = &input[r * row_size..(r + 1) * row_size];
            // Numerically stable: subtract max
            let max_val = row.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
            let exp_vals: Vec<f64> = row.iter().map(|&v| (v - max_val).exp()).collect();
            let sum: f64 = exp_vals.iter().sum();
            let sum_safe = if sum > 0.0 { sum } else { 1.0 };
            for (i, &e) in exp_vals.iter().enumerate() {
                output[r * row_size + i] = e / sum_safe;
            }
        }
        Ok(output)
    }

    fn exec_layer_norm(&self, node: &OpNode, cache: &HashMap<usize, Vec<f64>>) -> Result<Vec<f64>> {
        let input = self.get_input(node, 0, cache)?;
        let eps = node
            .attrs
            .get("eps")
            .and_then(|a| match a {
                OpAttr::Float(f) => Some(*f),
                _ => None,
            })
            .unwrap_or(1e-5);

        let shape = &node.output_spec.shape;
        let last_dim = shape.last().copied().unwrap_or(input.len());
        if last_dim == 0 {
            return Ok(input.clone());
        }

        let n_rows = input.len() / last_dim;

        // Fetch optional affine parameters
        let gamma_key = format!("layer_norm_{}_gamma", node.id);
        let beta_key = format!("layer_norm_{}_beta", node.id);
        let gamma = self.weight_map.get(&gamma_key);
        let beta = self.weight_map.get(&beta_key);

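        // Per row r: y[r, i] = gamma[i] * (x[r, i] - mean_r) / sqrt(var_r + eps) + beta[i];
        // gamma defaults to 1.0 and beta to 0.0 when no affine weights are stored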
        let mut output = vec![0.0_f64; input.len()];
        for r in 0..n_rows {
            let row = &input[r * last_dim..(r + 1) * last_dim];
            let mean = row.iter().sum::<f64>() / last_dim as f64;
            let var = row.iter().map(|&v| (v - mean).powi(2)).sum::<f64>() / last_dim as f64;
            let std_inv = 1.0 / (var + eps).sqrt();
            for (i, &v) in row.iter().enumerate() {
                let normalized = (v - mean) * std_inv;
                let scaled = gamma.and_then(|g| g.get(i).copied()).unwrap_or(1.0) * normalized;
                let shifted = scaled + beta.and_then(|b| b.get(i).copied()).unwrap_or(0.0);
                output[r * last_dim + i] = shifted;
            }
        }
        Ok(output)
    }
}

// ---------------------------------------------------------------------------
// Attribute helpers
// ---------------------------------------------------------------------------

fn get_attr_int(attrs: &HashMap<String, OpAttr>, key: &str) -> Result<i64> {
    match attrs.get(key) {
        Some(OpAttr::Int(v)) => Ok(*v),
        Some(_) => Err(Error::InvalidArgument(format!(
            "Attribute '{}' is not an integer",
            key
        ))),
        None => Err(Error::InvalidArgument(format!(
            "Missing attribute '{}'",
            key
        ))),
    }
}

// ---------------------------------------------------------------------------
// Graph optimization passes
// ---------------------------------------------------------------------------

/// Apply optimization passes to a `StaticGraph` and return the optimized graph.
///
/// Passes applied (in order):
/// 1. Dead node elimination — remove nodes not reachable backward from outputs
/// 2. Operator fusion — fuse Linear→ReLU pairs into `FusedLinearReLU`
///
/// Constant folding is currently a pass-through hook (see `_constant_folding`)
/// and is not invoked here.
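///
/// A minimal usage sketch, mirroring the tests below:
///
/// ```ignore
/// let optimized = optimize(&graph);
/// assert!(optimized.num_nodes() <= graph.num_nodes());
/// ```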
pub fn optimize(graph: &StaticGraph) -> StaticGraph {
    let after_dne = dead_node_elimination(graph);

    operator_fusion(&after_dne)
}

/// Remove nodes that are not reachable backward from the graph outputs.
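///
/// Implemented as a backward BFS: the live set is seeded with the output node
/// ids, then the producers of every live node are marked live in turn.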
fn dead_node_elimination(graph: &StaticGraph) -> StaticGraph {
    // BFS backward from output nodes
    let mut live: HashSet<usize> = HashSet::new();
    let mut queue: VecDeque<usize> = VecDeque::new();

    for &out_id in &graph.output_node_ids {
        if live.insert(out_id) {
            queue.push_back(out_id);
        }
    }

    // Reverse-edge map: node_id → ids of the nodes that produce its inputs
    let mut producers: HashMap<usize, Vec<usize>> = HashMap::new();
    for node in &graph.nodes {
        for &inp_id in &node.inputs {
            producers.entry(node.id).or_default().push(inp_id);
        }
    }

    while let Some(id) = queue.pop_front() {
        if let Some(prods) = producers.get(&id) {
            for &prod_id in prods {
                if live.insert(prod_id) {
                    queue.push_back(prod_id);
                }
            }
        }
    }

    // Keep only live nodes (preserve original order)
    let kept_nodes: Vec<OpNode> = graph
        .nodes
        .iter()
        .filter(|n| live.contains(&n.id))
        .cloned()
        .collect();

    let mut id_to_idx = HashMap::new();
    for (idx, node) in kept_nodes.iter().enumerate() {
        id_to_idx.insert(node.id, idx);
    }

    let mut new_graph = StaticGraph::new(graph.inputs.clone(), graph.outputs.clone());
    new_graph.nodes = kept_nodes;
    new_graph.id_to_idx = id_to_idx;
    new_graph.input_node_ids = graph.input_node_ids.clone();
    new_graph.output_node_ids = graph.output_node_ids.clone();
    new_graph
}

/// Constant-folding hook: Constant nodes with no tensor inputs could be
/// pre-evaluated at compile time. In this implementation the pass is a
/// no-op, since constant values are resolved at execution time via the
/// `weight_map`; it serves as the hook for future compile-time constant
/// propagation.
fn _constant_folding(graph: &StaticGraph) -> StaticGraph {
    // Currently a pass-through; constant values are resolved at execution time
    graph.clone()
}

/// Fuse consecutive Linear→ReLU nodes into a single `FusedLinearReLU` node.
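///
/// Sketch of the rewrite: `x → Linear(id=a) → ReLU(id=b) → y` becomes
/// `x → FusedLinearReLU(id=a) → y`; every edge (and graph output) that
/// referenced `b` is rewired to `a`.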
fn operator_fusion(graph: &StaticGraph) -> StaticGraph {
    let mut fused_nodes = graph.nodes.clone();

    // Find (Linear, ReLU) pairs where the ReLU is the sole consumer of the
    // Linear's output
    let mut to_fuse: Vec<(usize, usize)> = Vec::new(); // (linear_idx, relu_idx)
    for (relu_idx, node) in fused_nodes.iter().enumerate() {
        if node.op_type != OpType::ReLU {
            continue;
        }
        let relu_input_id = match node.inputs.first() {
            Some(&id) => id,
            None => continue,
        };
        // Check if the input is a Linear node
        let linear_idx = match fused_nodes
            .iter()
            .position(|n| n.id == relu_input_id && n.op_type == OpType::Linear)
        {
            Some(i) => i,
            None => continue,
        };
        // Check the Linear node has exactly one consumer (this ReLU) and is
        // not itself a graph output (fusing would change that output's values)
        let linear_consumer_count = fused_nodes
            .iter()
            .filter(|n| n.inputs.contains(&relu_input_id))
            .count();
        if linear_consumer_count == 1 && !graph.output_node_ids.contains(&relu_input_id) {
            to_fuse.push((linear_idx, relu_idx));
        }
    }

    // Apply fusions: replace Linear with FusedLinearReLU, mark ReLU for removal
    let mut remove_ids: HashSet<usize> = HashSet::new();
    let mut relu_id_to_linear_id: HashMap<usize, usize> = HashMap::new();

    for (linear_idx, relu_idx) in to_fuse {
        let relu_id = fused_nodes[relu_idx].id;
        let linear_id = fused_nodes[linear_idx].id;

        // Change the Linear to FusedLinearReLU
        fused_nodes[linear_idx].op_type = OpType::FusedLinearReLU;
        // The fused node keeps the Linear's output spec, which is identical
        // to the ReLU's (ReLU preserves shape)
        remove_ids.insert(relu_id);
        relu_id_to_linear_id.insert(relu_id, linear_id);
    }

    // Update all references to removed ReLU nodes to point to the fused Linear
    for node in &mut fused_nodes {
        for inp_id in &mut node.inputs {
            if let Some(&fused_id) = relu_id_to_linear_id.get(inp_id) {
                *inp_id = fused_id;
            }
        }
    }

    // Remove the now-redundant ReLU nodes
    fused_nodes.retain(|n| !remove_ids.contains(&n.id));

    // Rebuild the graph
    let mut id_to_idx = HashMap::new();
    for (idx, node) in fused_nodes.iter().enumerate() {
        id_to_idx.insert(node.id, idx);
    }

    // Update output_node_ids if any output pointed to a fused ReLU
    let output_node_ids: Vec<usize> = graph
        .output_node_ids
        .iter()
        .map(|&id| *relu_id_to_linear_id.get(&id).unwrap_or(&id))
        .collect();

    let mut new_graph = StaticGraph::new(graph.inputs.clone(), graph.outputs.clone());
    new_graph.nodes = fused_nodes;
    new_graph.id_to_idx = id_to_idx;
    new_graph.input_node_ids = graph.input_node_ids.clone();
    new_graph.output_node_ids = output_node_ids;
    new_graph
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tracing::graph_builder::GraphBuilder;
    use crate::tracing::types::{DType, TensorSpec};

    /// Build a simple weight map for a single linear layer.
    fn linear_weights(node_id: usize, in_f: usize, out_f: usize) -> HashMap<String, Vec<f64>> {
        let mut map = HashMap::new();
        // Identity-like weight: output o copies input o (for o < min(in_f, out_f))
        let mut weight = vec![0.0_f64; out_f * in_f];
        for o in 0..out_f.min(in_f) {
            weight[o * in_f + o] = 1.0;
        }
        map.insert(format!("linear_{}_weight", node_id), weight);
        map.insert(format!("linear_{}_bias", node_id), vec![0.0_f64; out_f]);
        map
    }

    #[test]
    fn test_executor_linear_relu() {
        let mut builder = GraphBuilder::new();
        let input = builder.input(TensorSpec::new(vec![1, 4], DType::F64));
        let h = builder.linear(input, 4, 4);
        let out = builder.relu(h);
        let graph = builder.build(vec![out]);

        // Find node IDs
        let linear_id = graph
            .nodes
            .iter()
            .find(|n| n.op_type == OpType::Linear)
            .map(|n| n.id)
            .expect("test: linear node");

        let mut weights = linear_weights(linear_id, 4, 4);
        // Set negative bias so some outputs will be negative before ReLU
        weights.insert(
            format!("linear_{}_bias", linear_id),
            vec![-1.0, -1.0, 1.0, 1.0],
        );

        let executor = GraphExecutor::new(graph, weights);
        let result = executor
            .run(&[vec![1.0, 2.0, 3.0, 4.0]])
            .expect("test: run");
        assert_eq!(result.len(), 1);
        let out = &result[0];
        assert_eq!(out.len(), 4);
        // ReLU: all outputs >= 0
        for &v in out {
            assert!(v >= 0.0, "ReLU output must be >= 0, got {v}");
        }
    }

    #[test]
    fn test_executor_softmax_sums_one() {
        let mut builder = GraphBuilder::new();
        let input = builder.input(TensorSpec::new(vec![1, 5], DType::F64));
        let out = builder.softmax(input, 1);
        let graph = builder.build(vec![out]);

        let executor = GraphExecutor::new(graph, HashMap::new());
        let result = executor
            .run(&[vec![1.0, 2.0, 3.0, 4.0, 5.0]])
            .expect("test: run softmax");
        let out = &result[0];
        assert_eq!(out.len(), 5);
        let sum: f64 = out.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-9,
            "Softmax should sum to 1, got {sum}"
        );
    }

    #[test]
    fn test_executor_layer_norm() {
        let mut builder = GraphBuilder::new();
        let input = builder.input(TensorSpec::new(vec![1, 8], DType::F64));
        let out = builder.layer_norm(input, 1e-5);
        let graph = builder.build(vec![out]);

        let executor = GraphExecutor::new(graph, HashMap::new());
        let data: Vec<f64> = (0..8).map(|i| i as f64).collect();
        let result = executor.run(&[data]).expect("test: run layer_norm");
        let out = &result[0];
        assert_eq!(out.len(), 8);
        let mean: f64 = out.iter().sum::<f64>() / out.len() as f64;
        let var: f64 = out.iter().map(|&v| (v - mean).powi(2)).sum::<f64>() / out.len() as f64;
        assert!(mean.abs() < 1e-6, "LayerNorm mean should be ~0, got {mean}");
        assert!(
            (var - 1.0).abs() < 1e-4,
            "LayerNorm variance should be ~1, got {var}"
        );
    }

    #[test]
    fn test_graph_dead_node_elimination() {
        let mut builder = GraphBuilder::new();
        let input = builder.input(TensorSpec::new(vec![1, 4], DType::F64));
        let h1 = builder.linear(input, 4, 4); // used
        let _dead = builder.relu(input); // dead (not used)
        let graph = builder.build(vec![h1]);

        let before_count = graph.num_nodes();
        let optimized = dead_node_elimination(&graph);
        // The dead ReLU branch should be removed
        assert!(
            optimized.num_nodes() < before_count,
            "Dead node elimination should reduce node count: before={before_count}, after={}",
            optimized.num_nodes()
        );
    }

    #[test]
    fn test_graph_constant_folding() {
        // Smoke test: optimize() completes without error and the graph keeps
        // its nodes (constant folding itself is currently a pass-through hook)
        let mut builder = GraphBuilder::new();
        let input = builder.input(TensorSpec::new(vec![1, 4], DType::F64));
        let out = builder.linear(input, 4, 2);
        let graph = builder.build(vec![out]);

        let optimized = optimize(&graph);
        // Should still have nodes
        assert!(optimized.num_nodes() > 0);
    }

    #[test]
    fn test_operator_fusion() {
        let mut builder = GraphBuilder::new();
        let input = builder.input(TensorSpec::new(vec![1, 4], DType::F64));
        let linear_out = builder.linear(input, 4, 4);
        let relu_out = builder.relu(linear_out);
        let graph = builder.build(vec![relu_out]);

        let before_count = graph.num_nodes();
        let fused = operator_fusion(&graph);
        // Linear + ReLU should become one FusedLinearReLU
        assert!(
            fused.num_nodes() < before_count,
            "Fusion should reduce node count: before={before_count}, after={}",
            fused.num_nodes()
        );
        let has_fused = fused
            .nodes
            .iter()
            .any(|n| n.op_type == OpType::FusedLinearReLU);
        assert!(
            has_fused,
            "Graph should contain FusedLinearReLU after fusion"
        );
    }

    #[test]
    fn test_static_graph_shapes_consistent() {
        let mut builder = GraphBuilder::new();
        let input = builder.input(TensorSpec::new(vec![1, 16], DType::F64));
        let h1 = builder.linear(input, 16, 8);
        let h2 = builder.relu(h1);
        let out = builder.linear(h2, 8, 4);
        let graph = builder.build(vec![out]);

        // Verify output shapes are consistent along the chain
        let linear_out_shapes: Vec<Vec<usize>> = graph
            .nodes
            .iter()
            .filter(|n| n.op_type == OpType::Linear)
            .map(|n| n.output_spec.shape.clone())
            .collect();

        // First linear: [1, 8], second: [1, 4]
        assert_eq!(linear_out_shapes[0], vec![1, 8]);
        assert_eq!(linear_out_shapes[1], vec![1, 4]);
    }
}
681}