Skip to main content

tenflowers_core/graph/
node_edge_ops.rs

1//! Node and Edge Management Operations
2//!
3//! This module provides functionality for managing individual nodes and edges,
4//! including removal, replacement, and redirection operations.
5
6use super::core::*;
7use crate::error::TensorError;
8
9impl Graph {
10    /// Remove a node and all its associated edges
11    pub fn remove_node(&mut self, node_id: NodeId) -> Result<(), TensorError> {
12        let node = self
13            .nodes
14            .get(&node_id)
15            .ok_or_else(|| TensorError::invalid_argument(format!("Node {} not found", node_id)))?;
16
17        // Collect all edges to remove
18        let mut edges_to_remove = Vec::new();
19        edges_to_remove.extend(node.inputs.iter());
20        edges_to_remove.extend(node.outputs.iter());
21
22        // Remove the node from name mapping
23        self.name_to_node.remove(&node.name);
24
25        // Remove all associated edges
26        for &edge_id in &edges_to_remove {
27            if let Some(edge) = self.edges.remove(&edge_id) {
28                // Update the other node's edge lists
29                if edge.from_node != node_id {
30                    if let Some(from_node) = self.nodes.get_mut(&edge.from_node) {
31                        from_node.outputs.retain(|&id| id != edge_id);
32                    }
33                }
34                if edge.to_node != node_id {
35                    if let Some(to_node) = self.nodes.get_mut(&edge.to_node) {
36                        to_node.inputs.retain(|&id| id != edge_id);
37                    }
38                }
39            }
40        }
41
42        // Remove the node
43        self.nodes.remove(&node_id);
44
45        self.topological_order = None; // Invalidate cached order
46        self.version += 1;
47
48        Ok(())
49    }
50
51    /// Remove an edge
52    pub fn remove_edge(&mut self, edge_id: EdgeId) -> Result<(), TensorError> {
53        let edge = self
54            .edges
55            .remove(&edge_id)
56            .ok_or_else(|| TensorError::invalid_argument(format!("Edge {} not found", edge_id)))?;
57
58        // Update node edge lists
59        if let Some(from_node) = self.nodes.get_mut(&edge.from_node) {
60            from_node.outputs.retain(|&id| id != edge_id);
61        }
62        if let Some(to_node) = self.nodes.get_mut(&edge.to_node) {
63            to_node.inputs.retain(|&id| id != edge_id);
64        }
65
66        self.topological_order = None; // Invalidate cached order
67        self.version += 1;
68
69        Ok(())
70    }
71
72    /// Replace a node with a constant value
73    pub fn replace_with_constant(
74        &mut self,
75        node_id: NodeId,
76        constant_value: crate::tensor::Tensor<f32>,
77    ) -> Result<(), TensorError> {
78        let node = self
79            .nodes
80            .get_mut(&node_id)
81            .ok_or_else(|| TensorError::invalid_argument(format!("Node {} not found", node_id)))?;
82
83        // Update node type to constant
84        node.op_type = NodeType::Constant;
85
86        // Add the constant value as an attribute
87        node.attributes
88            .insert("value".to_string(), AttributeValue::Tensor(constant_value));
89
90        // Remove all input edges since constants don't have inputs
91        let input_edges: Vec<EdgeId> = node.inputs.clone();
92        node.inputs.clear();
93
94        for edge_id in input_edges {
95            self.remove_edge(edge_id)?;
96        }
97
98        self.version += 1;
99
100        Ok(())
101    }
102
103    /// Redirect all outputs from one node to another
104    pub fn redirect_node_outputs(
105        &mut self,
106        from_node: NodeId,
107        to_node: NodeId,
108    ) -> Result<usize, TensorError> {
109        if !self.nodes.contains_key(&from_node) {
110            return Err(TensorError::invalid_argument(format!(
111                "Source node {} not found",
112                from_node
113            )));
114        }
115        if !self.nodes.contains_key(&to_node) {
116            return Err(TensorError::invalid_argument(format!(
117                "Target node {} not found",
118                to_node
119            )));
120        }
121
122        let output_edges: Vec<EdgeId> = self
123            .nodes
124            .get(&from_node)
125            .expect("Source node must exist after validation")
126            .outputs
127            .clone();
128
129        let mut redirected_count = 0;
130
131        for edge_id in output_edges {
132            if let Some(edge) = self.edges.get_mut(&edge_id) {
133                edge.from_node = to_node;
134                redirected_count += 1;
135
136                // Update node edge lists
137                self.nodes
138                    .get_mut(&to_node)
139                    .expect("Target node must exist after validation")
140                    .outputs
141                    .push(edge_id);
142            }
143        }
144
145        // Clear output edges from the original node
146        self.nodes
147            .get_mut(&from_node)
148            .expect("Source node must exist after validation")
149            .outputs
150            .clear();
151
152        if redirected_count > 0 {
153            self.topological_order = None; // Invalidate cached order
154            self.version += 1;
155        }
156
157        Ok(redirected_count)
158    }
159}