1use super::core::*;
7use crate::{device::Device, error::TensorError};
8use std::collections::HashMap;
9
10impl Graph {
11 pub fn extend_with_graph(
13 &mut self,
14 other: &Graph,
15 node_prefix: Option<&str>,
16 ) -> Result<HashMap<NodeId, NodeId>, TensorError> {
17 let mut id_mapping = HashMap::new();
18 let prefix = node_prefix.unwrap_or("");
19
20 for node in other.nodes.values() {
22 let new_name = if prefix.is_empty() {
23 node.name.clone()
24 } else {
25 format!("{prefix}_{}", node.name)
26 };
27
28 let new_id = self.add_node(
29 new_name,
30 node.op_type.clone(),
31 node.device,
32 node.attributes.clone(),
33 )?;
34 id_mapping.insert(node.id, new_id);
35 }
36
37 for edge in other.edges.values() {
39 let new_from = *id_mapping
40 .get(&edge.from_node)
41 .expect("Node ID must exist in mapping after insertion");
42 let new_to = *id_mapping
43 .get(&edge.to_node)
44 .expect("Node ID must exist in mapping after insertion");
45
46 self.add_edge(
47 new_from,
48 new_to,
49 edge.from_output,
50 edge.to_input,
51 edge.dtype,
52 edge.shape.clone(),
53 edge.is_control,
54 )?;
55 }
56
57 Ok(id_mapping)
58 }
59
60 pub fn integrate_subgraph(
62 &mut self,
63 subgraph: &Graph,
64 input_connections: &[(NodeId, usize, NodeId, usize)], output_connections: &[(NodeId, usize, NodeId, usize)], node_prefix: Option<&str>,
67 ) -> Result<HashMap<NodeId, NodeId>, TensorError> {
68 let id_mapping = self.extend_with_graph(subgraph, node_prefix)?;
70
71 for &(external_node, output_idx, subgraph_input, input_idx) in input_connections {
73 if !self.nodes.contains_key(&external_node) {
74 return Err(TensorError::invalid_argument(format!(
75 "External node {} not found",
76 external_node
77 )));
78 }
79
80 let mapped_subgraph_node = *id_mapping.get(&subgraph_input).ok_or_else(|| {
81 TensorError::invalid_argument(format!(
82 "Subgraph input node {} not found",
83 subgraph_input
84 ))
85 })?;
86
87 self.add_edge(
90 external_node,
91 mapped_subgraph_node,
92 output_idx,
93 input_idx,
94 crate::dtype::DType::Float32, crate::shape::Shape::new(vec![]),
96 false,
97 )?;
98 }
99
100 for &(subgraph_output, output_idx, external_node, input_idx) in output_connections {
102 if !self.nodes.contains_key(&external_node) {
103 return Err(TensorError::invalid_argument(format!(
104 "External node {} not found",
105 external_node
106 )));
107 }
108
109 let mapped_subgraph_node = *id_mapping.get(&subgraph_output).ok_or_else(|| {
110 TensorError::invalid_argument(format!(
111 "Subgraph output node {} not found",
112 subgraph_output
113 ))
114 })?;
115
116 self.add_edge(
118 mapped_subgraph_node,
119 external_node,
120 output_idx,
121 input_idx,
122 crate::dtype::DType::Float32, crate::shape::Shape::new(vec![]),
124 false,
125 )?;
126 }
127
128 Ok(id_mapping)
129 }
130
131 pub fn merge_graphs(graphs: &[&Graph]) -> Result<Graph, TensorError> {
133 let mut merged = Graph::new();
134
135 for (i, graph) in graphs.iter().enumerate() {
136 let prefix = format!("graph_{}", i);
137 merged.extend_with_graph(graph, Some(&prefix))?;
138 }
139
140 Ok(merged)
141 }
142
143 pub fn add_node_auto_name(
145 &mut self,
146 base_name: &str,
147 op_type: NodeType,
148 device: Device,
149 attributes: HashMap<String, AttributeValue>,
150 ) -> Result<NodeId, TensorError> {
151 let mut counter = 0;
152 let mut name = base_name.to_string();
153
154 while self.name_to_node.contains_key(&name) {
155 counter += 1;
156 name = format!("{}_{}", base_name, counter);
157 }
158
159 self.add_node(name, op_type, device, attributes)
160 }
161
162 pub fn add_operation_subgraph(
164 &mut self,
165 op_name: &str,
166 inputs: &[NodeId],
167 output_shapes: &[crate::shape::Shape],
168 output_dtypes: &[crate::dtype::DType],
169 device: Device,
170 attributes: HashMap<String, AttributeValue>,
171 ) -> Result<Vec<NodeId>, TensorError> {
172 let op_node = self.add_node_auto_name(
174 op_name,
175 NodeType::Operation(op_name.to_string()),
176 device,
177 attributes,
178 )?;
179
180 for (input_idx, &input_node) in inputs.iter().enumerate() {
182 if !self.nodes.contains_key(&input_node) {
183 return Err(TensorError::invalid_argument(format!(
184 "Input node {} not found",
185 input_node
186 )));
187 }
188
189 self.add_edge(
190 input_node,
191 op_node,
192 0, input_idx,
194 crate::dtype::DType::Float32, crate::shape::Shape::new(vec![]),
196 false,
197 )?;
198 }
199
200 let mut output_nodes = vec![op_node];
202 if output_shapes.len() > 1 {
203 for (output_idx, (shape, dtype)) in output_shapes
204 .iter()
205 .zip(output_dtypes.iter())
206 .enumerate()
207 .skip(1)
208 {
209 let output_node_name = format!("{}_output_{}", op_name, output_idx);
210 let output_node = self.add_node(
211 output_node_name,
212 NodeType::Operation("Identity".to_string()),
213 device,
214 HashMap::new(),
215 )?;
216
217 self.add_edge(
218 op_node,
219 output_node,
220 output_idx,
221 0,
222 *dtype,
223 shape.clone(),
224 false,
225 )?;
226
227 output_nodes.push(output_node);
228 }
229 }
230
231 Ok(output_nodes)
232 }
233
234 pub fn insert_node_between(
236 &mut self,
237 from_node: NodeId,
238 to_node: NodeId,
239 new_node_name: String,
240 new_node_type: NodeType,
241 device: Device,
242 attributes: HashMap<String, AttributeValue>,
243 ) -> Result<NodeId, TensorError> {
244 let edge_to_replace = self
246 .edges
247 .values()
248 .find(|edge| edge.from_node == from_node && edge.to_node == to_node && !edge.is_control)
249 .cloned();
250
251 let edge = edge_to_replace.ok_or_else(|| {
252 TensorError::invalid_argument(format!(
253 "No data edge found between nodes {} and {}",
254 from_node, to_node
255 ))
256 })?;
257
258 let new_node = self.add_node(new_node_name, new_node_type, device, attributes)?;
260
261 self.remove_edge(edge.id)?;
263
264 self.add_edge(
266 from_node,
267 new_node,
268 edge.from_output,
269 0,
270 edge.dtype,
271 edge.shape.clone(),
272 false,
273 )?;
274
275 self.add_edge(
276 new_node,
277 to_node,
278 0,
279 edge.to_input,
280 edge.dtype,
281 edge.shape,
282 false,
283 )?;
284
285 Ok(new_node)
286 }
287
288 pub fn replace_node_with_subgraph(
290 &mut self,
291 node_to_replace: NodeId,
292 replacement_graph: &Graph,
293 input_mapping: &HashMap<usize, NodeId>, output_mapping: &HashMap<usize, NodeId>, ) -> Result<HashMap<NodeId, NodeId>, TensorError> {
296 let node = self
297 .nodes
298 .get(&node_to_replace)
299 .ok_or_else(|| {
300 TensorError::invalid_argument(format!("Node {} not found", node_to_replace))
301 })?
302 .clone();
303
304 let incoming_edges: Vec<_> = node
306 .inputs
307 .iter()
308 .filter_map(|&edge_id| self.edges.get(&edge_id))
309 .cloned()
310 .collect();
311
312 let outgoing_edges: Vec<_> = node
313 .outputs
314 .iter()
315 .filter_map(|&edge_id| self.edges.get(&edge_id))
316 .cloned()
317 .collect();
318
319 self.remove_node(node_to_replace)?;
321
322 let id_mapping = self.extend_with_graph(replacement_graph, Some(&node.name))?;
324
325 for edge in incoming_edges {
327 if let Some(&replacement_input) = input_mapping.get(&edge.to_input) {
328 if let Some(&mapped_node) = id_mapping.get(&replacement_input) {
329 self.add_edge(
330 edge.from_node,
331 mapped_node,
332 edge.from_output,
333 0, edge.dtype,
335 edge.shape,
336 edge.is_control,
337 )?;
338 }
339 }
340 }
341
342 for edge in outgoing_edges {
344 if let Some(&replacement_output) = output_mapping.get(&edge.from_output) {
345 if let Some(&mapped_node) = id_mapping.get(&replacement_output) {
346 self.add_edge(
347 mapped_node,
348 edge.to_node,
349 0, edge.to_input,
351 edge.dtype,
352 edge.shape,
353 edge.is_control,
354 )?;
355 }
356 }
357 }
358
359 Ok(id_mapping)
360 }
361}