tenflowers_core/graph/
control_deps.rs1use super::core::*;
7use crate::error::TensorError;
8
9impl Graph {
10 pub fn add_control_dependency(
12 &mut self,
13 from_node: NodeId,
14 to_node: NodeId,
15 ) -> Result<EdgeId, TensorError> {
16 self.add_edge(
17 from_node,
18 to_node,
19 0,
20 0,
21 crate::dtype::DType::Float32, crate::shape::Shape::new(vec![]),
23 true, )
25 }
26
27 pub fn add_control_dependencies(
29 &mut self,
30 from_node: NodeId,
31 to_nodes: &[NodeId],
32 ) -> Result<Vec<EdgeId>, TensorError> {
33 let mut edge_ids = Vec::new();
34 for &to_node in to_nodes {
35 let edge_id = self.add_control_dependency(from_node, to_node)?;
36 edge_ids.push(edge_id);
37 }
38 Ok(edge_ids)
39 }
40
41 pub fn get_control_dependencies(&self, node_id: NodeId) -> Vec<NodeId> {
43 let node = match self.nodes.get(&node_id) {
44 Some(node) => node,
45 None => return Vec::new(),
46 };
47
48 node.inputs
49 .iter()
50 .filter_map(|&edge_id| self.edges.get(&edge_id))
51 .filter(|edge| edge.is_control)
52 .map(|edge| edge.from_node)
53 .collect()
54 }
55
56 pub fn get_control_dependents(&self, node_id: NodeId) -> Vec<NodeId> {
58 let node = match self.nodes.get(&node_id) {
59 Some(node) => node,
60 None => return Vec::new(),
61 };
62
63 node.outputs
64 .iter()
65 .filter_map(|&edge_id| self.edges.get(&edge_id))
66 .filter(|edge| edge.is_control)
67 .map(|edge| edge.to_node)
68 .collect()
69 }
70
71 pub fn has_control_dependency(&self, from_node: NodeId, to_node: NodeId) -> bool {
73 self.edges
74 .values()
75 .any(|edge| edge.is_control && edge.from_node == from_node && edge.to_node == to_node)
76 }
77
78 pub fn remove_control_dependencies(&mut self, node_id: NodeId) -> Result<usize, TensorError> {
80 let node = self
81 .nodes
82 .get(&node_id)
83 .ok_or_else(|| TensorError::invalid_argument(format!("Node {} not found", node_id)))?;
84
85 let mut control_edges_to_remove = Vec::new();
86
87 for &edge_id in &node.inputs {
89 if let Some(edge) = self.edges.get(&edge_id) {
90 if edge.is_control {
91 control_edges_to_remove.push(edge_id);
92 }
93 }
94 }
95
96 for &edge_id in &node.outputs {
98 if let Some(edge) = self.edges.get(&edge_id) {
99 if edge.is_control {
100 control_edges_to_remove.push(edge_id);
101 }
102 }
103 }
104
105 let removed_count = control_edges_to_remove.len();
106
107 for edge_id in control_edges_to_remove {
109 self.remove_edge(edge_id)?;
110 }
111
112 Ok(removed_count)
113 }
114
115 pub fn create_control_context(&mut self, context_nodes: &[NodeId]) -> Result<(), TensorError> {
117 for &node_id in context_nodes {
119 if !self.nodes.contains_key(&node_id) {
120 return Err(TensorError::invalid_argument(format!(
121 "Node {} not found",
122 node_id
123 )));
124 }
125 }
126
127 for window in context_nodes.windows(2) {
129 self.add_control_dependency(window[0], window[1])?;
130 }
131
132 Ok(())
133 }
134}