Skip to main content

tenflowers_core/graph/
subgraph.rs

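//! Subgraph Operations
//!
//! This module provides functionality for extracting subgraphs from a [`Graph`]:
//! by an explicit node set, operation type, device or custom predicate, and by
//! traversal (dependency slices, forward slices and connected components).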
5
6use super::core::*;
7use crate::{device::Device, error::TensorError};
8use std::collections::{HashMap, HashSet, VecDeque};
9
10impl Graph {
11    /// Create a subgraph containing only the specified nodes
12    pub fn subgraph(&self, node_ids: &[NodeId]) -> Result<Graph, TensorError> {
13        let node_set: HashSet<NodeId> = node_ids.iter().cloned().collect();
14        let mut subgraph = Graph::new();
15        let mut id_mapping: HashMap<NodeId, NodeId> = HashMap::new();
16
17        // Add nodes to subgraph
18        for &node_id in node_ids {
19            if let Some(node) = self.nodes.get(&node_id) {
20                let new_id = subgraph.add_node(
21                    node.name.clone(),
22                    node.op_type.clone(),
23                    node.device,
24                    node.attributes.clone(),
25                )?;
26                id_mapping.insert(node_id, new_id);
27            } else {
28                return Err(TensorError::invalid_argument(format!(
29                    "Node {} not found in graph",
30                    node_id
31                )));
32            }
33        }
34
35        // Add edges between included nodes
36        for edge in self.edges.values() {
37            if node_set.contains(&edge.from_node) && node_set.contains(&edge.to_node) {
38                let new_from = *id_mapping
39                    .get(&edge.from_node)
40                    .expect("Node ID must exist in mapping after insertion");
41                let new_to = *id_mapping
42                    .get(&edge.to_node)
43                    .expect("Node ID must exist in mapping after insertion");
44
45                subgraph.add_edge(
46                    new_from,
47                    new_to,
48                    edge.from_output,
49                    edge.to_input,
50                    edge.dtype,
51                    edge.shape.clone(),
52                    edge.is_control,
53                )?;
54            }
55        }
56
57        Ok(subgraph)
58    }
59
60    /// Create a subgraph containing nodes of specific operation types
61    pub fn subgraph_by_op_types(&self, op_types: &[&str]) -> Result<Graph, TensorError> {
62        let node_ids: Vec<NodeId> = self
63            .nodes
64            .values()
65            .filter(|node| match &node.op_type {
66                NodeType::Operation(op_name) => op_types.contains(&op_name.as_str()),
67                _ => false,
68            })
69            .map(|node| node.id)
70            .collect();
71
72        self.subgraph(&node_ids)
73    }
74
75    /// Create a subgraph containing nodes on a specific device
76    pub fn subgraph_by_device(&self, device: Device) -> Result<Graph, TensorError> {
77        let node_ids: Vec<NodeId> = self
78            .nodes
79            .values()
80            .filter(|node| node.device == device)
81            .map(|node| node.id)
82            .collect();
83
84        self.subgraph(&node_ids)
85    }
86
87    /// Create a subgraph with all dependencies of the specified nodes
88    pub fn subgraph_with_dependencies(
89        &self,
90        root_nodes: &[NodeId],
91        include_control_deps: bool,
92    ) -> Result<Graph, TensorError> {
93        let mut included_nodes = HashSet::new();
94        let mut queue = VecDeque::new();
95
96        // Start with root nodes
97        for &node_id in root_nodes {
98            if self.nodes.contains_key(&node_id) {
99                queue.push_back(node_id);
100                included_nodes.insert(node_id);
101            } else {
102                return Err(TensorError::invalid_argument(format!(
103                    "Node {} not found in graph",
104                    node_id
105                )));
106            }
107        }
108
109        // Traverse backwards through dependencies
110        while let Some(node_id) = queue.pop_front() {
111            if let Some(node) = self.nodes.get(&node_id) {
112                for &edge_id in &node.inputs {
113                    if let Some(edge) = self.edges.get(&edge_id) {
114                        if (!edge.is_control || include_control_deps)
115                            && !included_nodes.contains(&edge.from_node)
116                        {
117                            included_nodes.insert(edge.from_node);
118                            queue.push_back(edge.from_node);
119                        }
120                    }
121                }
122            }
123        }
124
125        let node_ids: Vec<NodeId> = included_nodes.into_iter().collect();
126        self.subgraph(&node_ids)
127    }
128
129    /// Find the connected component containing the specified node
130    pub fn connected_component(&self, start_node: NodeId) -> Result<Graph, TensorError> {
131        if !self.nodes.contains_key(&start_node) {
132            return Err(TensorError::invalid_argument(format!(
133                "Node {} not found in graph",
134                start_node
135            )));
136        }
137
138        let mut visited = HashSet::new();
139        let mut queue = VecDeque::new();
140
141        queue.push_back(start_node);
142        visited.insert(start_node);
143
144        // BFS to find all connected nodes
145        while let Some(node_id) = queue.pop_front() {
146            if let Some(node) = self.nodes.get(&node_id) {
147                // Check all input edges
148                for &edge_id in &node.inputs {
149                    if let Some(edge) = self.edges.get(&edge_id) {
150                        if !visited.contains(&edge.from_node) {
151                            visited.insert(edge.from_node);
152                            queue.push_back(edge.from_node);
153                        }
154                    }
155                }
156                // Check all output edges
157                for &edge_id in &node.outputs {
158                    if let Some(edge) = self.edges.get(&edge_id) {
159                        if !visited.contains(&edge.to_node) {
160                            visited.insert(edge.to_node);
161                            queue.push_back(edge.to_node);
162                        }
163                    }
164                }
165            }
166        }
167
168        let node_ids: Vec<NodeId> = visited.into_iter().collect();
169        self.subgraph(&node_ids)
170    }
171
172    /// Create a forward slice from the given nodes (includes all nodes reachable forward)
173    pub fn forward_slice(
174        &self,
175        start_nodes: &[NodeId],
176        include_control_deps: bool,
177    ) -> Result<Graph, TensorError> {
178        let mut included_nodes = HashSet::new();
179        let mut queue = VecDeque::new();
180
181        // Start with the specified nodes
182        for &node_id in start_nodes {
183            if self.nodes.contains_key(&node_id) {
184                queue.push_back(node_id);
185                included_nodes.insert(node_id);
186            } else {
187                return Err(TensorError::invalid_argument(format!(
188                    "Node {} not found in graph",
189                    node_id
190                )));
191            }
192        }
193
194        // Traverse forward through the graph
195        while let Some(node_id) = queue.pop_front() {
196            if let Some(node) = self.nodes.get(&node_id) {
197                for &edge_id in &node.outputs {
198                    if let Some(edge) = self.edges.get(&edge_id) {
199                        if (!edge.is_control || include_control_deps)
200                            && !included_nodes.contains(&edge.to_node)
201                        {
202                            included_nodes.insert(edge.to_node);
203                            queue.push_back(edge.to_node);
204                        }
205                    }
206                }
207            }
208        }
209
210        let node_ids: Vec<NodeId> = included_nodes.into_iter().collect();
211        self.subgraph(&node_ids)
212    }
213
214    /// Create a subgraph based on a custom predicate function
215    pub fn subgraph_by_predicate<F>(&self, predicate: F) -> Result<Graph, TensorError>
216    where
217        F: Fn(&GraphNode) -> bool,
218    {
219        let node_ids: Vec<NodeId> = self
220            .nodes
221            .values()
222            .filter(|node| predicate(node))
223            .map(|node| node.id)
224            .collect();
225
226        self.subgraph(&node_ids)
227    }
228}