Skip to main content

torsh_jit/graph/
core.rs

1//! Core graph representation structures
2
3use crate::{JitError, JitResult};
4use petgraph::graph::{DiGraph, NodeIndex};
5use petgraph::visit::EdgeRef;
6use petgraph::Direction;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use torsh_core::{DType, DeviceType, Shape};
10
11pub use crate::graph::metadata::GraphMetadata;
12pub use crate::graph::operations::Operation;
13
/// Identifier of a node within a [`ComputationGraph`]; alias of petgraph's
/// `NodeIndex` so callers don't depend on petgraph directly.
pub type NodeId = NodeIndex;
15
/// Edge in the computation graph representing data flow between nodes.
///
/// An edge connects one output slot of the source node to one input slot of
/// the destination node, so multi-output / multi-input operations can be
/// wired precisely. `Default` yields slot 0 → slot 0.
#[derive(Debug, Clone, Default)]
pub struct Edge {
    /// Output index of the source node
    pub src_output: usize,
    /// Input index of the destination node
    pub dst_input: usize,
}
24
/// Serializable wrapper for NodeIndex.
///
/// petgraph's `NodeIndex` does not implement the serde traits, so serialized
/// graphs store node indices through this `u32` newtype instead. Round-trip
/// conversions are provided by the adjacent `From` impls.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct SerializableNodeIndex(pub u32);
28
29impl From<NodeIndex> for SerializableNodeIndex {
30    fn from(node_index: NodeIndex) -> Self {
31        SerializableNodeIndex(node_index.index() as u32)
32    }
33}
34
35impl From<SerializableNodeIndex> for NodeIndex {
36    fn from(serializable: SerializableNodeIndex) -> Self {
37        NodeIndex::new(serializable.0 as usize)
38    }
39}
40
/// A node in the computation graph.
///
/// A node couples an [`Operation`] with the shape/dtype/device metadata the
/// JIT uses for analysis and optimization. Several fields exist twice: the
/// primary fields (`operation`, `dtypes`, `output_shapes`, `attributes`) and
/// legacy compatibility mirrors (`op`, `dtype`, `output_shape`, `attrs`) that
/// older call sites still read. The mirrors are kept consistent by
/// `sync_compatibility_fields`, which the `with_*` builders call.
#[derive(Debug, Clone)]
pub struct Node {
    /// Operation type
    pub operation: Operation,

    /// Node name/id (uniqueness is not enforced by this type)
    pub name: String,

    /// Input shapes; `None` entries mean "not yet inferred"
    pub input_shapes: Vec<Option<Shape>>,

    /// Output shapes; `None` entries mean "not yet inferred"
    pub output_shapes: Vec<Option<Shape>>,

    /// Data types for outputs (indexed in parallel with `output_shapes`)
    pub dtypes: Vec<DType>,

    /// Device information
    pub device: DeviceType,

    /// Additional attributes
    pub attributes: HashMap<String, crate::graph::operations::Attribute>,

    // Compatibility fields for existing code — do not update these directly;
    // mutate the primary fields and call `sync_compatibility_fields`.
    /// Operation alias for compatibility (mirror of `operation`)
    pub op: Operation,

    /// Single dtype for compatibility (first dtype from dtypes vec)
    pub dtype: DType,

    /// Single output shape for compatibility (first shape from output_shapes vec)
    pub output_shape: Shape,

    /// Attributes alias for compatibility (mirror of `attributes`)
    pub attrs: HashMap<String, crate::graph::operations::Attribute>,

    /// Input connections (placeholder for compatibility)
    pub inputs: Vec<NodeId>,

    /// Whether this is an output node (placeholder for compatibility)
    pub is_output: bool,
}
84
85impl Node {
86    /// Create a new node with the given operation
87    pub fn new(operation: Operation, name: String) -> Self {
88        let op = operation.clone();
89        let dtype = DType::F32; // Default dtype
90        let output_shape = Shape::new(vec![1]); // Default shape
91        let attributes = HashMap::new();
92
93        Self {
94            operation,
95            name,
96            input_shapes: Vec::new(),
97            output_shapes: Vec::new(),
98            dtypes: Vec::new(),
99            device: DeviceType::Cpu,
100            attributes: attributes.clone(),
101
102            // Compatibility fields
103            op,
104            dtype,
105            output_shape,
106            attrs: attributes,
107            inputs: Vec::new(),
108            is_output: false,
109        }
110    }
111
112    /// Set input shapes
113    pub fn with_input_shapes(mut self, shapes: Vec<Option<Shape>>) -> Self {
114        self.input_shapes = shapes;
115        self.sync_compatibility_fields();
116        self
117    }
118
119    /// Set output shapes
120    pub fn with_output_shapes(mut self, shapes: Vec<Option<Shape>>) -> Self {
121        self.output_shapes = shapes;
122        self.sync_compatibility_fields();
123        self
124    }
125
126    /// Set data types
127    pub fn with_dtypes(mut self, dtypes: Vec<DType>) -> Self {
128        self.dtypes = dtypes;
129        self.sync_compatibility_fields();
130        self
131    }
132
133    /// Set device
134    pub fn with_device(mut self, device: DeviceType) -> Self {
135        self.device = device;
136        self
137    }
138
139    /// Add an attribute
140    pub fn with_attribute(
141        mut self,
142        key: String,
143        value: crate::graph::operations::Attribute,
144    ) -> Self {
145        self.attributes.insert(key, value);
146        self.sync_compatibility_fields();
147        self
148    }
149
150    /// Get the number of inputs
151    pub fn num_inputs(&self) -> usize {
152        self.input_shapes.len()
153    }
154
155    /// Get the number of outputs
156    pub fn num_outputs(&self) -> usize {
157        self.output_shapes.len().max(1) // At least one output
158    }
159
160    /// Get input shape at index
161    pub fn input_shape(&self, index: usize) -> Option<&Shape> {
162        self.input_shapes.get(index).and_then(|s| s.as_ref())
163    }
164
165    /// Get output shape at index
166    pub fn output_shape(&self, index: usize) -> Option<&Shape> {
167        self.output_shapes.get(index).and_then(|s| s.as_ref())
168    }
169
170    /// Get data type at output index
171    pub fn dtype(&self, index: usize) -> Option<&DType> {
172        self.dtypes.get(index)
173    }
174
175    /// Check if this is an input node
176    pub fn is_input(&self) -> bool {
177        matches!(self.operation, Operation::Input | Operation::Parameter(_))
178    }
179
180    /// Check if this is a constant node
181    pub fn is_constant(&self) -> bool {
182        matches!(self.operation, Operation::Constant(_))
183    }
184
185    /// Check if this is a control flow node
186    pub fn is_control_flow(&self) -> bool {
187        matches!(
188            self.operation,
189            Operation::If(_)
190                | Operation::While(_)
191                | Operation::For(_)
192                | Operation::Break
193                | Operation::Continue
194                | Operation::Return(_)
195                | Operation::Block(_)
196                | Operation::Merge(_)
197        )
198    }
199
200    /// Get memory estimate in bytes
201    pub fn memory_estimate(&self) -> usize {
202        let mut total = 0;
203        for shape_opt in &self.output_shapes {
204            if let Some(shape) = shape_opt {
205                let elements = shape.dims().iter().product::<usize>();
206                // Assume each element is at least 4 bytes
207                total += elements * 4;
208            }
209        }
210        total
211    }
212
213    /// Get computational complexity estimate (FLOPs)
214    pub fn complexity_estimate(&self) -> usize {
215        match &self.operation {
216            Operation::MatMul | Operation::BatchMatMul => {
217                if self.input_shapes.len() >= 2 {
218                    if let (Some(Some(a_shape)), Some(Some(b_shape))) =
219                        (self.input_shapes.get(0), self.input_shapes.get(1))
220                    {
221                        // Matrix multiplication complexity: 2 * m * n * k
222                        if a_shape.dims().len() >= 2 && b_shape.dims().len() >= 2 {
223                            let m = a_shape.dims()[a_shape.dims().len() - 2];
224                            let k = a_shape.dims()[a_shape.dims().len() - 1];
225                            let n = b_shape.dims()[b_shape.dims().len() - 1];
226                            return 2 * m * n * k;
227                        }
228                    }
229                }
230                0
231            }
232            Operation::Conv2d(_) => {
233                // Simplified convolution complexity estimation
234                if let Some(Some(output_shape)) = self.output_shapes.get(0) {
235                    output_shape.dims().iter().product::<usize>() * 9 // 3x3 kernel approximation
236                } else {
237                    0
238                }
239            }
240            _ => {
241                // For other operations, estimate based on output size
242                if let Some(Some(output_shape)) = self.output_shapes.get(0) {
243                    output_shape.dims().iter().product::<usize>()
244                } else {
245                    1
246                }
247            }
248        }
249    }
250
251    /// Synchronize compatibility fields with main fields
252    pub fn sync_compatibility_fields(&mut self) {
253        self.op = self.operation.clone();
254        self.dtype = self.dtypes.first().copied().unwrap_or(DType::F32);
255        self.output_shape = self
256            .output_shapes
257            .first()
258            .and_then(|s| s.as_ref())
259            .cloned()
260            .unwrap_or_else(|| Shape::new(vec![1]));
261        self.attrs = self.attributes.clone();
262    }
263
264    /// Set attribute (compatibility method)
265    pub fn set_attribute(&mut self, key: String, value: crate::graph::operations::Attribute) {
266        self.attributes.insert(key.clone(), value.clone());
267        self.attrs.insert(key, value);
268    }
269
270    /// Set optimization hint (compatibility method)
271    pub fn set_optimization_hint(&mut self, hint: &str, value: &str) -> crate::JitResult<()> {
272        let attr_value = crate::graph::operations::Attribute::String(value.to_string());
273        self.set_attribute(hint.to_string(), attr_value);
274        Ok(())
275    }
276
277    /// Get attribute (compatibility method)
278    pub fn get_attribute(&self, key: &str) -> Option<&crate::graph::operations::Attribute> {
279        self.attributes.get(key)
280    }
281
282    /// Get operation type (compatibility method)
283    pub fn operation_type(&self) -> &str {
284        self.operation.as_str()
285    }
286
287    /// Check if node has side effects (compatibility method)
288    pub fn has_side_effects(&self) -> bool {
289        matches!(
290            self.operation,
291            Operation::Custom(_) | Operation::Break | Operation::Continue | Operation::Return(_)
292        )
293    }
294
295    /// Get operation category for optimization purposes
296    pub fn operation_category(&self) -> OperationCategory {
297        match &self.operation {
298            Operation::Add
299            | Operation::Sub
300            | Operation::Mul
301            | Operation::Div
302            | Operation::Neg
303            | Operation::Abs
304            | Operation::Exp
305            | Operation::Log
306            | Operation::Sqrt
307            | Operation::Sin
308            | Operation::Cos
309            | Operation::Tanh
310            | Operation::Sigmoid
311            | Operation::Relu
312            | Operation::Gelu
313            | Operation::Silu => OperationCategory::ElementWise,
314            Operation::MatMul | Operation::BatchMatMul => OperationCategory::LinearAlgebra,
315            Operation::Conv2d(_) | Operation::Linear(_) => OperationCategory::NeuralNetwork,
316            Operation::Sum { .. }
317            | Operation::Mean { .. }
318            | Operation::Max { .. }
319            | Operation::Min { .. } => OperationCategory::Reduction,
320            Operation::Reshape { .. }
321            | Operation::Transpose { .. }
322            | Operation::Squeeze { .. }
323            | Operation::Unsqueeze { .. }
324            | Operation::Slice { .. }
325            | Operation::Concat { .. } => OperationCategory::ShapeManipulation,
326            Operation::If(_)
327            | Operation::While(_)
328            | Operation::For(_)
329            | Operation::Break
330            | Operation::Continue
331            | Operation::Return(_)
332            | Operation::Block(_)
333            | Operation::Merge(_) => OperationCategory::ControlFlow,
334            Operation::Input | Operation::Parameter(_) | Operation::Constant(_) => {
335                OperationCategory::Input
336            }
337            _ => OperationCategory::Other,
338        }
339    }
340
341    /// Check if this operation can be vectorized using SIMD instructions
342    pub fn is_vectorizable(&self) -> bool {
343        match &self.operation {
344            // Element-wise operations are highly vectorizable
345            Operation::Add
346            | Operation::Sub
347            | Operation::Mul
348            | Operation::Div
349            | Operation::Neg
350            | Operation::Abs
351            | Operation::Exp
352            | Operation::Log
353            | Operation::Sqrt
354            | Operation::Sin
355            | Operation::Cos
356            | Operation::Tanh
357            | Operation::Sigmoid
358            | Operation::Relu
359            | Operation::Gelu
360            | Operation::Silu => true,
361            // Matrix operations can benefit from vectorization
362            Operation::MatMul | Operation::BatchMatMul => true,
363            // Reduction operations can be vectorized
364            Operation::Sum { .. }
365            | Operation::Mean { .. }
366            | Operation::Max { .. }
367            | Operation::Min { .. } => true,
368            // Convolutions are vectorizable
369            Operation::Conv2d(_) => true,
370            // Other operations are typically not vectorizable
371            _ => false,
372        }
373    }
374
375    /// Check if this operation accesses memory (for cache optimization)
376    pub fn has_memory_access(&self) -> bool {
377        match &self.operation {
378            // Operations that don't access external memory
379            Operation::Input | Operation::Parameter(_) | Operation::Constant(_) => false,
380            // Control flow operations typically don't access memory directly
381            Operation::Break | Operation::Continue | Operation::Return(_) => false,
382            // All computation operations access memory
383            _ => true,
384        }
385    }
386
387    /// Estimate the working set size (bytes) for memory access patterns
388    pub fn estimate_working_set_size(&self) -> usize {
389        let mut working_set = 0;
390
391        // Input working set (data being read)
392        for shape_opt in &self.input_shapes {
393            if let Some(shape) = shape_opt {
394                let elements = shape.dims().iter().product::<usize>();
395                // Assume each element is at least 4 bytes (f32)
396                working_set += elements * 4;
397            }
398        }
399
400        // Output working set (data being written)
401        for shape_opt in &self.output_shapes {
402            if let Some(shape) = shape_opt {
403                let elements = shape.dims().iter().product::<usize>();
404                working_set += elements * 4;
405            }
406        }
407
408        // Operation-specific working set adjustments
409        match &self.operation {
410            Operation::MatMul | Operation::BatchMatMul => {
411                // Matrix multiplication has intermediate results
412                working_set * 2
413            }
414            Operation::Conv2d(_) => {
415                // Convolution may need workspace for im2col
416                working_set * 3
417            }
418            _ => working_set,
419        }
420    }
421}
422
/// Categories of operations for optimization and analysis.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OperationCategory {
    /// Pointwise math (add, mul, activations, …)
    ElementWise,
    /// Matrix products (matmul, batched matmul)
    LinearAlgebra,
    /// Neural-network layers (conv, linear)
    NeuralNetwork,
    /// Reductions (sum, mean, max, min)
    Reduction,
    /// Shape-only transforms (reshape, transpose, slice, concat, …)
    ShapeManipulation,
    /// Control flow (if/while/for, break/continue/return, block, merge)
    ControlFlow,
    /// Graph sources (inputs, parameters, constants)
    Input,
    /// Anything not covered by the categories above
    Other,
}
435
/// Computation graph representing a neural network or computation.
///
/// Backed by a petgraph `DiGraph`, which is a multigraph: multiple parallel
/// edges between the same pair of nodes are allowed (e.g. one per
/// output-slot/input-slot pair). `inputs`/`outputs` hold the node ids that
/// are the graph's external interface; they are kept as plain lists and must
/// reference live nodes (see `validate`).
#[derive(Debug, Clone)]
pub struct ComputationGraph {
    /// Internal graph representation
    pub(crate) graph: DiGraph<Node, Edge>,

    /// Input nodes
    pub inputs: Vec<NodeId>,

    /// Output nodes
    pub outputs: Vec<NodeId>,

    /// Metadata
    pub metadata: GraphMetadata,
}
451
452impl ComputationGraph {
453    /// Create a new empty computation graph
454    pub fn new() -> Self {
455        Self {
456            graph: DiGraph::new(),
457            inputs: Vec::new(),
458            outputs: Vec::new(),
459            metadata: GraphMetadata::default(),
460        }
461    }
462
463    /// Add a node to the graph
464    pub fn add_node(&mut self, node: Node) -> NodeId {
465        self.graph.add_node(node)
466    }
467
468    /// Add an edge between nodes
469    pub fn add_edge(&mut self, from: NodeId, to: NodeId, edge: Edge) {
470        self.graph.add_edge(from, to, edge);
471    }
472
473    /// Mark a node as input
474    pub fn add_input(&mut self, node: NodeId) {
475        if !self.inputs.contains(&node) {
476            self.inputs.push(node);
477        }
478    }
479
480    /// Mark a node as output
481    pub fn add_output(&mut self, node: NodeId) {
482        if !self.outputs.contains(&node) {
483            self.outputs.push(node);
484        }
485    }
486
487    /// Get all nodes
488    pub fn nodes(&self) -> impl Iterator<Item = (NodeId, &Node)> {
489        self.graph
490            .node_indices()
491            .map(move |idx| (idx, &self.graph[idx]))
492    }
493
494    /// Get all edges
495    pub fn edges(&self) -> impl Iterator<Item = (NodeId, NodeId, &Edge)> + '_ {
496        self.graph.edge_indices().map(move |idx| {
497            let (src, dst) = self
498                .graph
499                .edge_endpoints(idx)
500                .expect("edge index should be valid");
501            (src, dst, &self.graph[idx])
502        })
503    }
504
505    /// Get node by ID
506    pub fn get_node(&self, id: NodeId) -> Option<&Node> {
507        self.graph.node_weight(id)
508    }
509
510    /// Get mutable node by ID
511    pub fn get_node_mut(&mut self, id: NodeId) -> Option<&mut Node> {
512        self.graph.node_weight_mut(id)
513    }
514
515    /// Get node inputs
516    pub fn get_node_inputs(&self, id: NodeId) -> Vec<NodeId> {
517        self.graph
518            .neighbors_directed(id, Direction::Incoming)
519            .collect()
520    }
521
522    /// Get node outputs
523    pub fn get_node_outputs(&self, id: NodeId) -> Vec<NodeId> {
524        self.graph
525            .neighbors_directed(id, Direction::Outgoing)
526            .collect()
527    }
528
529    /// Get incoming edges for a node
530    pub fn incoming_edges(&self, id: NodeId) -> Vec<(NodeId, NodeId, &Edge)> {
531        self.graph
532            .edges_directed(id, Direction::Incoming)
533            .map(|edge_ref| (edge_ref.source(), edge_ref.target(), edge_ref.weight()))
534            .collect()
535    }
536
537    /// Get outgoing edges for a node
538    pub fn outgoing_edges(&self, id: NodeId) -> Vec<(NodeId, NodeId, &Edge)> {
539        self.graph
540            .edges_directed(id, Direction::Outgoing)
541            .map(|edge_ref| (edge_ref.source(), edge_ref.target(), edge_ref.weight()))
542            .collect()
543    }
544
545    /// Remove a node from the graph
546    pub fn remove_node(&mut self, id: NodeId) -> Option<Node> {
547        // Remove from inputs/outputs
548        self.inputs.retain(|&x| x != id);
549        self.outputs.retain(|&x| x != id);
550
551        self.graph.remove_node(id)
552    }
553
554    /// Remove an edge from the graph
555    pub fn remove_edge(&mut self, from: NodeId, to: NodeId) -> bool {
556        if let Some(edge_id) = self.graph.find_edge(from, to) {
557            self.graph.remove_edge(edge_id).is_some()
558        } else {
559            false
560        }
561    }
562
563    /// Get number of nodes
564    pub fn node_count(&self) -> usize {
565        self.graph.node_count()
566    }
567
568    /// Get number of edges
569    pub fn edge_count(&self) -> usize {
570        self.graph.edge_count()
571    }
572
573    /// Check if graph is empty
574    pub fn is_empty(&self) -> bool {
575        self.graph.node_count() == 0
576    }
577
578    /// Validate the graph structure
579    pub fn validate(&self) -> JitResult<()> {
580        // Check that all input/output node IDs exist
581        for &input_id in &self.inputs {
582            if self.graph.node_weight(input_id).is_none() {
583                return Err(JitError::GraphError(format!(
584                    "Input node {:?} does not exist in graph",
585                    input_id
586                )));
587            }
588        }
589
590        for &output_id in &self.outputs {
591            if self.graph.node_weight(output_id).is_none() {
592                return Err(JitError::GraphError(format!(
593                    "Output node {:?} does not exist in graph",
594                    output_id
595                )));
596            }
597        }
598
599        // Check for cycles in non-control-flow subgraph
600        self.validate_acyclic()?;
601
602        Ok(())
603    }
604
605    /// Check that the graph is acyclic (ignoring control flow edges)
606    fn validate_acyclic(&self) -> JitResult<()> {
607        use petgraph::algo::is_cyclic_directed;
608
609        if is_cyclic_directed(&self.graph) {
610            return Err(JitError::GraphError("Graph contains cycles".to_string()));
611        }
612
613        Ok(())
614    }
615
616    /// Get topological ordering of nodes
617    pub fn topological_sort(&self) -> JitResult<Vec<NodeId>> {
618        use petgraph::algo::toposort;
619
620        toposort(&self.graph, None)
621            .map_err(|_| JitError::GraphError("Graph contains cycles".to_string()))
622    }
623
624    /// Clone with only specified nodes
625    pub fn subgraph(&self, node_ids: &[NodeId]) -> JitResult<ComputationGraph> {
626        let mut new_graph = ComputationGraph::new();
627        let mut node_mapping = HashMap::new();
628
629        // Add nodes
630        for &node_id in node_ids {
631            if let Some(node) = self.get_node(node_id) {
632                let new_id = new_graph.add_node(node.clone());
633                node_mapping.insert(node_id, new_id);
634            } else {
635                return Err(JitError::GraphError(format!(
636                    "Node {:?} not found in original graph",
637                    node_id
638                )));
639            }
640        }
641
642        // Add edges between included nodes
643        for &src_id in node_ids {
644            for &dst_id in node_ids {
645                if let Some(edge_ref) = self.graph.find_edge(src_id, dst_id) {
646                    let edge = self.graph.edge_weight(edge_ref).expect("edge should exist");
647                    let new_src = node_mapping[&src_id];
648                    let new_dst = node_mapping[&dst_id];
649                    new_graph.add_edge(new_src, new_dst, edge.clone());
650                }
651            }
652        }
653
654        // Update inputs/outputs
655        for &input_id in &self.inputs {
656            if let Some(&new_id) = node_mapping.get(&input_id) {
657                new_graph.add_input(new_id);
658            }
659        }
660
661        for &output_id in &self.outputs {
662            if let Some(&new_id) = node_mapping.get(&output_id) {
663                new_graph.add_output(new_id);
664            }
665        }
666
667        new_graph.metadata = self.metadata.clone();
668
669        Ok(new_graph)
670    }
671
672    /// Get strongly connected components
673    pub fn strongly_connected_components(&self) -> Vec<Vec<NodeId>> {
674        use petgraph::algo::tarjan_scc;
675        tarjan_scc(&self.graph)
676    }
677
678    /// Get memory usage estimate in bytes
679    pub fn memory_estimate(&self) -> usize {
680        self.graph
681            .node_weights()
682            .map(|node| node.memory_estimate())
683            .sum()
684    }
685
686    /// Get computational complexity estimate (FLOPs)
687    pub fn complexity_estimate(&self) -> usize {
688        self.graph
689            .node_weights()
690            .map(|node| node.complexity_estimate())
691            .sum()
692    }
693
694    /// Get predecessors of a node (compatibility method)
695    pub fn predecessors(&self, node_id: NodeId) -> impl Iterator<Item = NodeId> + '_ {
696        self.graph.neighbors_directed(node_id, Direction::Incoming)
697    }
698
699    /// Get successors of a node (compatibility method)
700    pub fn successors(&self, node_id: NodeId) -> impl Iterator<Item = NodeId> + '_ {
701        self.graph.neighbors_directed(node_id, Direction::Outgoing)
702    }
703
704    /// Get node by ID (compatibility method)
705    pub fn node(&self, id: NodeId) -> Option<&Node> {
706        self.get_node(id)
707    }
708
709    /// Get mutable node by ID (compatibility method)
710    pub fn node_mut(&mut self, id: NodeId) -> Option<&mut Node> {
711        self.get_node_mut(id)
712    }
713
714    /// Get directed edges for a node (compatibility method)
715    pub fn edges_directed(
716        &self,
717        node_id: NodeId,
718        direction: Direction,
719    ) -> impl Iterator<Item = petgraph::graph::EdgeReference<'_, Edge>> {
720        self.graph.edges_directed(node_id, direction)
721    }
722
723    /// Check if the graph is acyclic (compatibility method)
724    pub fn is_acyclic(&self) -> bool {
725        use petgraph::algo::is_cyclic_directed;
726        !is_cyclic_directed(&self.graph)
727    }
728
729    /// Replace a node with one of its inputs (for constant folding and branch elimination)
730    ///
731    /// This operation:
732    /// 1. Redirects all edges coming into `node_id` to `replacement_id`
733    /// 2. Redirects all edges going out of `node_id` to come from `replacement_id`
734    /// 3. Removes `node_id` from the graph
735    ///
736    /// # Arguments
737    ///
738    /// * `node_id` - The node to replace
739    /// * `replacement_id` - The input node that will replace it
740    ///
741    /// # Returns
742    ///
743    /// * `Ok(())` if successful
744    /// * `Err(JitError)` if the replacement would create an invalid graph
745    pub fn replace_node_with_input(
746        &mut self,
747        node_id: NodeId,
748        replacement_id: NodeId,
749    ) -> crate::JitResult<()> {
750        // Validate that replacement_id is actually an input to node_id
751        let is_predecessor = self
752            .predecessors(node_id)
753            .any(|pred| pred == replacement_id);
754
755        if !is_predecessor {
756            return Err(crate::JitError::CompilationError(format!(
757                "Node {:?} is not a predecessor of node {:?}",
758                replacement_id, node_id
759            )));
760        }
761
762        // Collect all successor edges before modification
763        let successors: Vec<(NodeId, Edge)> = self
764            .graph
765            .edges_directed(node_id, Direction::Outgoing)
766            .map(|edge_ref| (edge_ref.target(), edge_ref.weight().clone()))
767            .collect();
768
769        // Redirect all outgoing edges to come from replacement_id instead
770        for (successor_id, edge) in successors {
771            self.graph.add_edge(replacement_id, successor_id, edge);
772        }
773
774        // Update outputs list if node_id was an output
775        if let Some(pos) = self.outputs.iter().position(|&id| id == node_id) {
776            self.outputs[pos] = replacement_id;
777        }
778
779        // Remove the replaced node (this also cleans up inputs/outputs lists)
780        self.remove_node(node_id);
781
782        Ok(())
783    }
784
785    /// Replace a node with a sequence of nodes (for loop unrolling and macro expansion)
786    ///
787    /// This operation:
788    /// 1. Inserts the sequence of nodes into the graph
789    /// 2. Connects the first node in the sequence to the inputs of `node_id`
790    /// 3. Connects the last node in the sequence to the outputs of `node_id`
791    /// 4. Removes `node_id` from the graph
792    ///
793    /// # Arguments
794    ///
795    /// * `node_id` - The node to replace
796    /// * `sequence` - The sequence of nodes to insert (must not be empty)
797    ///
798    /// # Returns
799    ///
800    /// * `Ok(())` if successful
801    /// * `Err(JitError)` if the sequence is empty or would create an invalid graph
802    pub fn replace_node_with_sequence(
803        &mut self,
804        node_id: NodeId,
805        sequence: &[Node],
806    ) -> crate::JitResult<()> {
807        if sequence.is_empty() {
808            return Err(crate::JitError::CompilationError(
809                "Cannot replace node with empty sequence".to_string(),
810            ));
811        }
812
813        // Add all nodes in the sequence
814        let sequence_ids: Vec<NodeId> = sequence
815            .iter()
816            .map(|node| self.graph.add_node(node.clone()))
817            .collect();
818
819        let first_id = sequence_ids[0];
820        let last_id = *sequence_ids.last().expect("sequence should not be empty");
821
822        // Connect nodes in the sequence to each other
823        for window in sequence_ids.windows(2) {
824            let edge = Edge {
825                src_output: 0,
826                dst_input: 0,
827            };
828            self.graph.add_edge(window[0], window[1], edge);
829        }
830
831        // Collect predecessor edges before modification
832        let predecessors: Vec<(NodeId, Edge)> = self
833            .graph
834            .edges_directed(node_id, Direction::Incoming)
835            .map(|edge_ref| (edge_ref.source(), edge_ref.weight().clone()))
836            .collect();
837
838        // Redirect incoming edges to the first node in the sequence
839        for (pred_id, edge) in predecessors {
840            self.graph.add_edge(pred_id, first_id, edge);
841        }
842
843        // Collect successor edges before modification
844        let successors: Vec<(NodeId, Edge)> = self
845            .graph
846            .edges_directed(node_id, Direction::Outgoing)
847            .map(|edge_ref| (edge_ref.target(), edge_ref.weight().clone()))
848            .collect();
849
850        // Redirect outgoing edges to come from the last node in the sequence
851        for (succ_id, edge) in successors {
852            self.graph.add_edge(last_id, succ_id, edge);
853        }
854
855        // Update inputs list if node_id was an input
856        if let Some(pos) = self.inputs.iter().position(|&id| id == node_id) {
857            self.inputs[pos] = first_id;
858        }
859
860        // Update outputs list if node_id was an output
861        if let Some(pos) = self.outputs.iter().position(|&id| id == node_id) {
862            self.outputs[pos] = last_id;
863        }
864
865        // Remove the replaced node (this also cleans up inputs/outputs lists)
866        self.remove_node(node_id);
867
868        Ok(())
869    }
870}
871
872impl Default for ComputationGraph {
873    fn default() -> Self {
874        Self::new()
875    }
876}
877
878/// Utility function to create a Shape from a slice of dimensions
879pub fn shape_from_slice(dims: &[usize]) -> Shape {
880    Shape::new(dims.to_vec())
881}