// tenflowers_core/graph/node_edge_ops.rs
//! Node and edge mutation operations (removal, constant replacement, output
//! redirection) implemented on `Graph`.

use super::core::*;
use crate::error::TensorError;

9impl Graph {
10 pub fn remove_node(&mut self, node_id: NodeId) -> Result<(), TensorError> {
12 let node = self
13 .nodes
14 .get(&node_id)
15 .ok_or_else(|| TensorError::invalid_argument(format!("Node {} not found", node_id)))?;
16
17 let mut edges_to_remove = Vec::new();
19 edges_to_remove.extend(node.inputs.iter());
20 edges_to_remove.extend(node.outputs.iter());
21
22 self.name_to_node.remove(&node.name);
24
25 for &edge_id in &edges_to_remove {
27 if let Some(edge) = self.edges.remove(&edge_id) {
28 if edge.from_node != node_id {
30 if let Some(from_node) = self.nodes.get_mut(&edge.from_node) {
31 from_node.outputs.retain(|&id| id != edge_id);
32 }
33 }
34 if edge.to_node != node_id {
35 if let Some(to_node) = self.nodes.get_mut(&edge.to_node) {
36 to_node.inputs.retain(|&id| id != edge_id);
37 }
38 }
39 }
40 }
41
42 self.nodes.remove(&node_id);
44
45 self.topological_order = None; self.version += 1;
47
48 Ok(())
49 }
50
51 pub fn remove_edge(&mut self, edge_id: EdgeId) -> Result<(), TensorError> {
53 let edge = self
54 .edges
55 .remove(&edge_id)
56 .ok_or_else(|| TensorError::invalid_argument(format!("Edge {} not found", edge_id)))?;
57
58 if let Some(from_node) = self.nodes.get_mut(&edge.from_node) {
60 from_node.outputs.retain(|&id| id != edge_id);
61 }
62 if let Some(to_node) = self.nodes.get_mut(&edge.to_node) {
63 to_node.inputs.retain(|&id| id != edge_id);
64 }
65
66 self.topological_order = None; self.version += 1;
68
69 Ok(())
70 }
71
72 pub fn replace_with_constant(
74 &mut self,
75 node_id: NodeId,
76 constant_value: crate::tensor::Tensor<f32>,
77 ) -> Result<(), TensorError> {
78 let node = self
79 .nodes
80 .get_mut(&node_id)
81 .ok_or_else(|| TensorError::invalid_argument(format!("Node {} not found", node_id)))?;
82
83 node.op_type = NodeType::Constant;
85
86 node.attributes
88 .insert("value".to_string(), AttributeValue::Tensor(constant_value));
89
90 let input_edges: Vec<EdgeId> = node.inputs.clone();
92 node.inputs.clear();
93
94 for edge_id in input_edges {
95 self.remove_edge(edge_id)?;
96 }
97
98 self.version += 1;
99
100 Ok(())
101 }
102
103 pub fn redirect_node_outputs(
105 &mut self,
106 from_node: NodeId,
107 to_node: NodeId,
108 ) -> Result<usize, TensorError> {
109 if !self.nodes.contains_key(&from_node) {
110 return Err(TensorError::invalid_argument(format!(
111 "Source node {} not found",
112 from_node
113 )));
114 }
115 if !self.nodes.contains_key(&to_node) {
116 return Err(TensorError::invalid_argument(format!(
117 "Target node {} not found",
118 to_node
119 )));
120 }
121
122 let output_edges: Vec<EdgeId> = self
123 .nodes
124 .get(&from_node)
125 .expect("Source node must exist after validation")
126 .outputs
127 .clone();
128
129 let mut redirected_count = 0;
130
131 for edge_id in output_edges {
132 if let Some(edge) = self.edges.get_mut(&edge_id) {
133 edge.from_node = to_node;
134 redirected_count += 1;
135
136 self.nodes
138 .get_mut(&to_node)
139 .expect("Target node must exist after validation")
140 .outputs
141 .push(edge_id);
142 }
143 }
144
145 self.nodes
147 .get_mut(&from_node)
148 .expect("Source node must exist after validation")
149 .outputs
150 .clear();
151
152 if redirected_count > 0 {
153 self.topological_order = None; self.version += 1;
155 }
156
157 Ok(redirected_count)
158 }
159}