Skip to main content

tenflowers_core/graph/
manipulation.rs

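//! Graph manipulation operations.
//!
//! This module provides functionality for manipulating and transforming graphs:
//! extending one graph with another, merging several graphs, integrating subgraphs,
//! inserting a node between two connected nodes, and replacing a node with a subgraph.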
5
6use super::core::*;
7use crate::{device::Device, error::TensorError};
8use std::collections::HashMap;
9
10impl Graph {
11    /// Extend this graph with another graph
12    pub fn extend_with_graph(
13        &mut self,
14        other: &Graph,
15        node_prefix: Option<&str>,
16    ) -> Result<HashMap<NodeId, NodeId>, TensorError> {
17        let mut id_mapping = HashMap::new();
18        let prefix = node_prefix.unwrap_or("");
19
20        // Add all nodes from the other graph
21        for node in other.nodes.values() {
22            let new_name = if prefix.is_empty() {
23                node.name.clone()
24            } else {
25                format!("{prefix}_{}", node.name)
26            };
27
28            let new_id = self.add_node(
29                new_name,
30                node.op_type.clone(),
31                node.device,
32                node.attributes.clone(),
33            )?;
34            id_mapping.insert(node.id, new_id);
35        }
36
37        // Add all edges from the other graph
38        for edge in other.edges.values() {
39            let new_from = *id_mapping
40                .get(&edge.from_node)
41                .expect("Node ID must exist in mapping after insertion");
42            let new_to = *id_mapping
43                .get(&edge.to_node)
44                .expect("Node ID must exist in mapping after insertion");
45
46            self.add_edge(
47                new_from,
48                new_to,
49                edge.from_output,
50                edge.to_input,
51                edge.dtype,
52                edge.shape.clone(),
53                edge.is_control,
54            )?;
55        }
56
57        Ok(id_mapping)
58    }
59
60    /// Integrate a subgraph into this graph at specific connection points
61    pub fn integrate_subgraph(
62        &mut self,
63        subgraph: &Graph,
64        input_connections: &[(NodeId, usize, NodeId, usize)], // (external_node, output_idx, subgraph_input, input_idx)
65        output_connections: &[(NodeId, usize, NodeId, usize)], // (subgraph_output, output_idx, external_node, input_idx)
66        node_prefix: Option<&str>,
67    ) -> Result<HashMap<NodeId, NodeId>, TensorError> {
68        // First extend the graph with the subgraph
69        let id_mapping = self.extend_with_graph(subgraph, node_prefix)?;
70
71        // Create input connections
72        for &(external_node, output_idx, subgraph_input, input_idx) in input_connections {
73            if !self.nodes.contains_key(&external_node) {
74                return Err(TensorError::invalid_argument(format!(
75                    "External node {} not found",
76                    external_node
77                )));
78            }
79
80            let mapped_subgraph_node = *id_mapping.get(&subgraph_input).ok_or_else(|| {
81                TensorError::invalid_argument(format!(
82                    "Subgraph input node {} not found",
83                    subgraph_input
84                ))
85            })?;
86
87            // Create edge from external node to subgraph input
88            // Note: This is a simplified version - in practice, we'd need to infer types and shapes
89            self.add_edge(
90                external_node,
91                mapped_subgraph_node,
92                output_idx,
93                input_idx,
94                crate::dtype::DType::Float32, // Default type
95                crate::shape::Shape::new(vec![]),
96                false,
97            )?;
98        }
99
100        // Create output connections
101        for &(subgraph_output, output_idx, external_node, input_idx) in output_connections {
102            if !self.nodes.contains_key(&external_node) {
103                return Err(TensorError::invalid_argument(format!(
104                    "External node {} not found",
105                    external_node
106                )));
107            }
108
109            let mapped_subgraph_node = *id_mapping.get(&subgraph_output).ok_or_else(|| {
110                TensorError::invalid_argument(format!(
111                    "Subgraph output node {} not found",
112                    subgraph_output
113                ))
114            })?;
115
116            // Create edge from subgraph output to external node
117            self.add_edge(
118                mapped_subgraph_node,
119                external_node,
120                output_idx,
121                input_idx,
122                crate::dtype::DType::Float32, // Default type
123                crate::shape::Shape::new(vec![]),
124                false,
125            )?;
126        }
127
128        Ok(id_mapping)
129    }
130
131    /// Merge multiple graphs into a single graph
132    pub fn merge_graphs(graphs: &[&Graph]) -> Result<Graph, TensorError> {
133        let mut merged = Graph::new();
134
135        for (i, graph) in graphs.iter().enumerate() {
136            let prefix = format!("graph_{}", i);
137            merged.extend_with_graph(graph, Some(&prefix))?;
138        }
139
140        Ok(merged)
141    }
142
143    /// Add a node with automatically generated unique name
144    pub fn add_node_auto_name(
145        &mut self,
146        base_name: &str,
147        op_type: NodeType,
148        device: Device,
149        attributes: HashMap<String, AttributeValue>,
150    ) -> Result<NodeId, TensorError> {
151        let mut counter = 0;
152        let mut name = base_name.to_string();
153
154        while self.name_to_node.contains_key(&name) {
155            counter += 1;
156            name = format!("{}_{}", base_name, counter);
157        }
158
159        self.add_node(name, op_type, device, attributes)
160    }
161
162    /// Add a complete operation subgraph (operation with inputs and outputs)
163    pub fn add_operation_subgraph(
164        &mut self,
165        op_name: &str,
166        inputs: &[NodeId],
167        output_shapes: &[crate::shape::Shape],
168        output_dtypes: &[crate::dtype::DType],
169        device: Device,
170        attributes: HashMap<String, AttributeValue>,
171    ) -> Result<Vec<NodeId>, TensorError> {
172        // Create the operation node
173        let op_node = self.add_node_auto_name(
174            op_name,
175            NodeType::Operation(op_name.to_string()),
176            device,
177            attributes,
178        )?;
179
180        // Connect inputs to the operation
181        for (input_idx, &input_node) in inputs.iter().enumerate() {
182            if !self.nodes.contains_key(&input_node) {
183                return Err(TensorError::invalid_argument(format!(
184                    "Input node {} not found",
185                    input_node
186                )));
187            }
188
189            self.add_edge(
190                input_node,
191                op_node,
192                0, // Assume single output from input node
193                input_idx,
194                crate::dtype::DType::Float32, // Default - should be inferred
195                crate::shape::Shape::new(vec![]),
196                false,
197            )?;
198        }
199
200        // Create output nodes if multiple outputs
201        let mut output_nodes = vec![op_node];
202        if output_shapes.len() > 1 {
203            for (output_idx, (shape, dtype)) in output_shapes
204                .iter()
205                .zip(output_dtypes.iter())
206                .enumerate()
207                .skip(1)
208            {
209                let output_node_name = format!("{}_output_{}", op_name, output_idx);
210                let output_node = self.add_node(
211                    output_node_name,
212                    NodeType::Operation("Identity".to_string()),
213                    device,
214                    HashMap::new(),
215                )?;
216
217                self.add_edge(
218                    op_node,
219                    output_node,
220                    output_idx,
221                    0,
222                    *dtype,
223                    shape.clone(),
224                    false,
225                )?;
226
227                output_nodes.push(output_node);
228            }
229        }
230
231        Ok(output_nodes)
232    }
233
234    /// Insert a new node between two existing connected nodes
235    pub fn insert_node_between(
236        &mut self,
237        from_node: NodeId,
238        to_node: NodeId,
239        new_node_name: String,
240        new_node_type: NodeType,
241        device: Device,
242        attributes: HashMap<String, AttributeValue>,
243    ) -> Result<NodeId, TensorError> {
244        // Find the edge to replace
245        let edge_to_replace = self
246            .edges
247            .values()
248            .find(|edge| edge.from_node == from_node && edge.to_node == to_node && !edge.is_control)
249            .cloned();
250
251        let edge = edge_to_replace.ok_or_else(|| {
252            TensorError::invalid_argument(format!(
253                "No data edge found between nodes {} and {}",
254                from_node, to_node
255            ))
256        })?;
257
258        // Create the new node
259        let new_node = self.add_node(new_node_name, new_node_type, device, attributes)?;
260
261        // Remove the original edge
262        self.remove_edge(edge.id)?;
263
264        // Create new edges: from_node -> new_node -> to_node
265        self.add_edge(
266            from_node,
267            new_node,
268            edge.from_output,
269            0,
270            edge.dtype,
271            edge.shape.clone(),
272            false,
273        )?;
274
275        self.add_edge(
276            new_node,
277            to_node,
278            0,
279            edge.to_input,
280            edge.dtype,
281            edge.shape,
282            false,
283        )?;
284
285        Ok(new_node)
286    }
287
288    /// Replace a node with a subgraph
289    pub fn replace_node_with_subgraph(
290        &mut self,
291        node_to_replace: NodeId,
292        replacement_graph: &Graph,
293        input_mapping: &HashMap<usize, NodeId>, // input_index -> replacement_node_id
294        output_mapping: &HashMap<usize, NodeId>, // output_index -> replacement_node_id
295    ) -> Result<HashMap<NodeId, NodeId>, TensorError> {
296        let node = self
297            .nodes
298            .get(&node_to_replace)
299            .ok_or_else(|| {
300                TensorError::invalid_argument(format!("Node {} not found", node_to_replace))
301            })?
302            .clone();
303
304        // Store incoming and outgoing edges
305        let incoming_edges: Vec<_> = node
306            .inputs
307            .iter()
308            .filter_map(|&edge_id| self.edges.get(&edge_id))
309            .cloned()
310            .collect();
311
312        let outgoing_edges: Vec<_> = node
313            .outputs
314            .iter()
315            .filter_map(|&edge_id| self.edges.get(&edge_id))
316            .cloned()
317            .collect();
318
319        // Remove the original node
320        self.remove_node(node_to_replace)?;
321
322        // Add the replacement graph
323        let id_mapping = self.extend_with_graph(replacement_graph, Some(&node.name))?;
324
325        // Reconnect incoming edges
326        for edge in incoming_edges {
327            if let Some(&replacement_input) = input_mapping.get(&edge.to_input) {
328                if let Some(&mapped_node) = id_mapping.get(&replacement_input) {
329                    self.add_edge(
330                        edge.from_node,
331                        mapped_node,
332                        edge.from_output,
333                        0, // Connect to first input of replacement node
334                        edge.dtype,
335                        edge.shape,
336                        edge.is_control,
337                    )?;
338                }
339            }
340        }
341
342        // Reconnect outgoing edges
343        for edge in outgoing_edges {
344            if let Some(&replacement_output) = output_mapping.get(&edge.from_output) {
345                if let Some(&mapped_node) = id_mapping.get(&replacement_output) {
346                    self.add_edge(
347                        mapped_node,
348                        edge.to_node,
349                        0, // Connect from first output of replacement node
350                        edge.to_input,
351                        edge.dtype,
352                        edge.shape,
353                        edge.is_control,
354                    )?;
355                }
356            }
357        }
358
359        Ok(id_mapping)
360    }
361}