Skip to main content

tenflowers_core/graph/
control_deps.rs

//! Control Dependency Management
//!
//! This module adds methods to `Graph` for creating, querying, and removing
//! control dependencies between nodes in the computation graph.
5
6use super::core::*;
7use crate::error::TensorError;
8
9impl Graph {
10    /// Add a control dependency between two nodes
11    pub fn add_control_dependency(
12        &mut self,
13        from_node: NodeId,
14        to_node: NodeId,
15    ) -> Result<EdgeId, TensorError> {
16        self.add_edge(
17            from_node,
18            to_node,
19            0,
20            0,
21            crate::dtype::DType::Float32, // Control dependencies don't have meaningful types
22            crate::shape::Shape::new(vec![]),
23            true, // This is a control edge
24        )
25    }
26
27    /// Add control dependencies from one node to multiple nodes
28    pub fn add_control_dependencies(
29        &mut self,
30        from_node: NodeId,
31        to_nodes: &[NodeId],
32    ) -> Result<Vec<EdgeId>, TensorError> {
33        let mut edge_ids = Vec::new();
34        for &to_node in to_nodes {
35            let edge_id = self.add_control_dependency(from_node, to_node)?;
36            edge_ids.push(edge_id);
37        }
38        Ok(edge_ids)
39    }
40
41    /// Get all nodes that this node has control dependencies on
42    pub fn get_control_dependencies(&self, node_id: NodeId) -> Vec<NodeId> {
43        let node = match self.nodes.get(&node_id) {
44            Some(node) => node,
45            None => return Vec::new(),
46        };
47
48        node.inputs
49            .iter()
50            .filter_map(|&edge_id| self.edges.get(&edge_id))
51            .filter(|edge| edge.is_control)
52            .map(|edge| edge.from_node)
53            .collect()
54    }
55
56    /// Get all nodes that depend on this node via control dependencies
57    pub fn get_control_dependents(&self, node_id: NodeId) -> Vec<NodeId> {
58        let node = match self.nodes.get(&node_id) {
59            Some(node) => node,
60            None => return Vec::new(),
61        };
62
63        node.outputs
64            .iter()
65            .filter_map(|&edge_id| self.edges.get(&edge_id))
66            .filter(|edge| edge.is_control)
67            .map(|edge| edge.to_node)
68            .collect()
69    }
70
71    /// Check if there's a control dependency between two nodes
72    pub fn has_control_dependency(&self, from_node: NodeId, to_node: NodeId) -> bool {
73        self.edges
74            .values()
75            .any(|edge| edge.is_control && edge.from_node == from_node && edge.to_node == to_node)
76    }
77
78    /// Remove all control dependencies for a node (both incoming and outgoing)
79    pub fn remove_control_dependencies(&mut self, node_id: NodeId) -> Result<usize, TensorError> {
80        let node = self
81            .nodes
82            .get(&node_id)
83            .ok_or_else(|| TensorError::invalid_argument(format!("Node {} not found", node_id)))?;
84
85        let mut control_edges_to_remove = Vec::new();
86
87        // Find incoming control edges
88        for &edge_id in &node.inputs {
89            if let Some(edge) = self.edges.get(&edge_id) {
90                if edge.is_control {
91                    control_edges_to_remove.push(edge_id);
92                }
93            }
94        }
95
96        // Find outgoing control edges
97        for &edge_id in &node.outputs {
98            if let Some(edge) = self.edges.get(&edge_id) {
99                if edge.is_control {
100                    control_edges_to_remove.push(edge_id);
101                }
102            }
103        }
104
105        let removed_count = control_edges_to_remove.len();
106
107        // Remove the control edges
108        for edge_id in control_edges_to_remove {
109            self.remove_edge(edge_id)?;
110        }
111
112        Ok(removed_count)
113    }
114
115    /// Create a control context - ensure all nodes in the list execute in order
116    pub fn create_control_context(&mut self, context_nodes: &[NodeId]) -> Result<(), TensorError> {
117        // Validate all nodes exist
118        for &node_id in context_nodes {
119            if !self.nodes.contains_key(&node_id) {
120                return Err(TensorError::invalid_argument(format!(
121                    "Node {} not found",
122                    node_id
123                )));
124            }
125        }
126
127        // Create control dependencies between consecutive nodes
128        for window in context_nodes.windows(2) {
129            self.add_control_dependency(window[0], window[1])?;
130        }
131
132        Ok(())
133    }
134}