tenflowers_core/graph/
subgraph.rs1use super::core::*;
7use crate::{device::Device, error::TensorError};
8use std::collections::{HashMap, HashSet, VecDeque};
9
10impl Graph {
11 pub fn subgraph(&self, node_ids: &[NodeId]) -> Result<Graph, TensorError> {
13 let node_set: HashSet<NodeId> = node_ids.iter().cloned().collect();
14 let mut subgraph = Graph::new();
15 let mut id_mapping: HashMap<NodeId, NodeId> = HashMap::new();
16
17 for &node_id in node_ids {
19 if let Some(node) = self.nodes.get(&node_id) {
20 let new_id = subgraph.add_node(
21 node.name.clone(),
22 node.op_type.clone(),
23 node.device,
24 node.attributes.clone(),
25 )?;
26 id_mapping.insert(node_id, new_id);
27 } else {
28 return Err(TensorError::invalid_argument(format!(
29 "Node {} not found in graph",
30 node_id
31 )));
32 }
33 }
34
35 for edge in self.edges.values() {
37 if node_set.contains(&edge.from_node) && node_set.contains(&edge.to_node) {
38 let new_from = *id_mapping
39 .get(&edge.from_node)
40 .expect("Node ID must exist in mapping after insertion");
41 let new_to = *id_mapping
42 .get(&edge.to_node)
43 .expect("Node ID must exist in mapping after insertion");
44
45 subgraph.add_edge(
46 new_from,
47 new_to,
48 edge.from_output,
49 edge.to_input,
50 edge.dtype,
51 edge.shape.clone(),
52 edge.is_control,
53 )?;
54 }
55 }
56
57 Ok(subgraph)
58 }
59
60 pub fn subgraph_by_op_types(&self, op_types: &[&str]) -> Result<Graph, TensorError> {
62 let node_ids: Vec<NodeId> = self
63 .nodes
64 .values()
65 .filter(|node| match &node.op_type {
66 NodeType::Operation(op_name) => op_types.contains(&op_name.as_str()),
67 _ => false,
68 })
69 .map(|node| node.id)
70 .collect();
71
72 self.subgraph(&node_ids)
73 }
74
75 pub fn subgraph_by_device(&self, device: Device) -> Result<Graph, TensorError> {
77 let node_ids: Vec<NodeId> = self
78 .nodes
79 .values()
80 .filter(|node| node.device == device)
81 .map(|node| node.id)
82 .collect();
83
84 self.subgraph(&node_ids)
85 }
86
87 pub fn subgraph_with_dependencies(
89 &self,
90 root_nodes: &[NodeId],
91 include_control_deps: bool,
92 ) -> Result<Graph, TensorError> {
93 let mut included_nodes = HashSet::new();
94 let mut queue = VecDeque::new();
95
96 for &node_id in root_nodes {
98 if self.nodes.contains_key(&node_id) {
99 queue.push_back(node_id);
100 included_nodes.insert(node_id);
101 } else {
102 return Err(TensorError::invalid_argument(format!(
103 "Node {} not found in graph",
104 node_id
105 )));
106 }
107 }
108
109 while let Some(node_id) = queue.pop_front() {
111 if let Some(node) = self.nodes.get(&node_id) {
112 for &edge_id in &node.inputs {
113 if let Some(edge) = self.edges.get(&edge_id) {
114 if (!edge.is_control || include_control_deps)
115 && !included_nodes.contains(&edge.from_node)
116 {
117 included_nodes.insert(edge.from_node);
118 queue.push_back(edge.from_node);
119 }
120 }
121 }
122 }
123 }
124
125 let node_ids: Vec<NodeId> = included_nodes.into_iter().collect();
126 self.subgraph(&node_ids)
127 }
128
129 pub fn connected_component(&self, start_node: NodeId) -> Result<Graph, TensorError> {
131 if !self.nodes.contains_key(&start_node) {
132 return Err(TensorError::invalid_argument(format!(
133 "Node {} not found in graph",
134 start_node
135 )));
136 }
137
138 let mut visited = HashSet::new();
139 let mut queue = VecDeque::new();
140
141 queue.push_back(start_node);
142 visited.insert(start_node);
143
144 while let Some(node_id) = queue.pop_front() {
146 if let Some(node) = self.nodes.get(&node_id) {
147 for &edge_id in &node.inputs {
149 if let Some(edge) = self.edges.get(&edge_id) {
150 if !visited.contains(&edge.from_node) {
151 visited.insert(edge.from_node);
152 queue.push_back(edge.from_node);
153 }
154 }
155 }
156 for &edge_id in &node.outputs {
158 if let Some(edge) = self.edges.get(&edge_id) {
159 if !visited.contains(&edge.to_node) {
160 visited.insert(edge.to_node);
161 queue.push_back(edge.to_node);
162 }
163 }
164 }
165 }
166 }
167
168 let node_ids: Vec<NodeId> = visited.into_iter().collect();
169 self.subgraph(&node_ids)
170 }
171
172 pub fn forward_slice(
174 &self,
175 start_nodes: &[NodeId],
176 include_control_deps: bool,
177 ) -> Result<Graph, TensorError> {
178 let mut included_nodes = HashSet::new();
179 let mut queue = VecDeque::new();
180
181 for &node_id in start_nodes {
183 if self.nodes.contains_key(&node_id) {
184 queue.push_back(node_id);
185 included_nodes.insert(node_id);
186 } else {
187 return Err(TensorError::invalid_argument(format!(
188 "Node {} not found in graph",
189 node_id
190 )));
191 }
192 }
193
194 while let Some(node_id) = queue.pop_front() {
196 if let Some(node) = self.nodes.get(&node_id) {
197 for &edge_id in &node.outputs {
198 if let Some(edge) = self.edges.get(&edge_id) {
199 if (!edge.is_control || include_control_deps)
200 && !included_nodes.contains(&edge.to_node)
201 {
202 included_nodes.insert(edge.to_node);
203 queue.push_back(edge.to_node);
204 }
205 }
206 }
207 }
208 }
209
210 let node_ids: Vec<NodeId> = included_nodes.into_iter().collect();
211 self.subgraph(&node_ids)
212 }
213
214 pub fn subgraph_by_predicate<F>(&self, predicate: F) -> Result<Graph, TensorError>
216 where
217 F: Fn(&GraphNode) -> bool,
218 {
219 let node_ids: Vec<NodeId> = self
220 .nodes
221 .values()
222 .filter(|node| predicate(node))
223 .map(|node| node.id)
224 .collect();
225
226 self.subgraph(&node_ids)
227 }
228}