Skip to main content

tenflowers_core/graph/
core.rs

1//! Core Graph Types and Basic Operations
2//!
3//! This module contains the fundamental data structures for the computation graph
4//! and basic operations for creating and accessing graph elements.
5
6use crate::{device::Device, dtype::DType, error::TensorError, shape::Shape, tensor::Tensor};
7use std::collections::HashMap;
8
/// Identifier of a node within a computation `Graph`.
///
/// Handed out sequentially, starting at zero, by `Graph::add_node`.
pub type NodeId = u64;

/// Identifier of an edge within a computation `Graph`.
///
/// Handed out sequentially, starting at zero, by `Graph::add_edge`.
pub type EdgeId = u64;
14
15/// Graph node representing an operation, variable, placeholder, or constant
16#[derive(Clone, Debug)]
17pub struct GraphNode {
18    pub id: NodeId,
19    pub name: String,
20    pub op_type: NodeType,
21    pub device: Device,
22    pub inputs: Vec<EdgeId>,
23    pub outputs: Vec<EdgeId>,
24    pub attributes: HashMap<String, AttributeValue>,
25}
26
27/// Types of nodes in the computation graph
28#[derive(Clone, Debug, PartialEq)]
29pub enum NodeType {
30    /// Operation node that performs computation
31    Operation(String), // op name
32    /// Variable node that holds mutable state
33    Variable { dtype: DType, shape: Shape },
34    /// Placeholder node for feeding inputs
35    Placeholder { dtype: DType, shape: Shape },
36    /// Constant node with fixed value
37    Constant,
38}
39
40/// Edge representing data flow between nodes
41#[derive(Clone, Debug)]
42pub struct GraphEdge {
43    pub id: EdgeId,
44    pub from_node: NodeId,
45    pub to_node: NodeId,
46    pub from_output: usize, // output index from source node
47    pub to_input: usize,    // input index to destination node
48    pub dtype: DType,
49    pub shape: Shape,
50    pub is_control: bool, // true for control dependencies
51}
52
53/// Attribute values that can be attached to nodes
54#[derive(Clone, Debug)]
55pub enum AttributeValue {
56    String(String),
57    Int(i64),
58    Float(f64),
59    Bool(bool),
60    IntList(Vec<i64>),
61    FloatList(Vec<f64>),
62    Shape(Shape),
63    Tensor(Tensor<f32>), // For constants
64}
65
66/// Main computation graph structure
67#[derive(Debug, Clone)]
68pub struct Graph {
69    pub(crate) nodes: HashMap<NodeId, GraphNode>,
70    pub(crate) edges: HashMap<EdgeId, GraphEdge>,
71    pub(crate) next_node_id: NodeId,
72    pub(crate) next_edge_id: EdgeId,
73    pub(crate) name_to_node: HashMap<String, NodeId>,
74    pub(crate) topological_order: Option<Vec<NodeId>>,
75    pub(crate) version: u64,
76}
77
78impl Graph {
79    /// Create a new empty graph
80    pub fn new() -> Self {
81        Self {
82            nodes: HashMap::new(),
83            edges: HashMap::new(),
84            next_node_id: 0,
85            next_edge_id: 0,
86            name_to_node: HashMap::new(),
87            topological_order: None,
88            version: 0,
89        }
90    }
91
92    /// Add a new node to the graph
93    pub fn add_node(
94        &mut self,
95        name: String,
96        op_type: NodeType,
97        device: Device,
98        attributes: HashMap<String, AttributeValue>,
99    ) -> Result<NodeId, TensorError> {
100        // Ensure unique names
101        if self.name_to_node.contains_key(&name) {
102            return Err(TensorError::invalid_argument(format!(
103                "Node name '{name}' already exists"
104            )));
105        }
106
107        let node_id = self.next_node_id;
108        self.next_node_id += 1;
109
110        let node = GraphNode {
111            id: node_id,
112            name: name.clone(),
113            op_type,
114            device,
115            inputs: Vec::new(),
116            outputs: Vec::new(),
117            attributes,
118        };
119
120        self.nodes.insert(node_id, node);
121        self.name_to_node.insert(name, node_id);
122        self.topological_order = None; // Invalidate cached order
123        self.version += 1;
124
125        Ok(node_id)
126    }
127
128    /// Add a new edge to the graph
129    #[allow(clippy::too_many_arguments)]
130    pub fn add_edge(
131        &mut self,
132        from_node: NodeId,
133        to_node: NodeId,
134        from_output: usize,
135        to_input: usize,
136        dtype: DType,
137        shape: Shape,
138        is_control: bool,
139    ) -> Result<EdgeId, TensorError> {
140        // Validate nodes exist
141        if !self.nodes.contains_key(&from_node) {
142            return Err(TensorError::invalid_argument(format!(
143                "Source node {from_node} not found"
144            )));
145        }
146        if !self.nodes.contains_key(&to_node) {
147            return Err(TensorError::invalid_argument(format!(
148                "Destination node {to_node} not found"
149            )));
150        }
151
152        // Check for cycles (simplified - just direct cycle)
153        if from_node == to_node {
154            return Err(TensorError::invalid_argument(
155                "Self-loops are not allowed".to_string(),
156            ));
157        }
158
159        let edge_id = self.next_edge_id;
160        self.next_edge_id += 1;
161
162        let edge = GraphEdge {
163            id: edge_id,
164            from_node,
165            to_node,
166            from_output,
167            to_input,
168            dtype,
169            shape,
170            is_control,
171        };
172
173        self.edges.insert(edge_id, edge);
174
175        // Update node edge lists
176        self.nodes
177            .get_mut(&from_node)
178            .expect("Source node must exist after validation")
179            .outputs
180            .push(edge_id);
181        self.nodes
182            .get_mut(&to_node)
183            .expect("Destination node must exist after validation")
184            .inputs
185            .push(edge_id);
186
187        self.topological_order = None; // Invalidate cached order
188        self.version += 1;
189
190        Ok(edge_id)
191    }
192
193    /// Get a node by ID
194    pub fn get_node(&self, node_id: NodeId) -> Option<&GraphNode> {
195        self.nodes.get(&node_id)
196    }
197
198    /// Get a node by name
199    pub fn get_node_by_name(&self, name: &str) -> Option<&GraphNode> {
200        self.name_to_node
201            .get(name)
202            .and_then(|&id| self.nodes.get(&id))
203    }
204
205    /// Get an edge by ID
206    pub fn get_edge(&self, edge_id: EdgeId) -> Option<&GraphEdge> {
207        self.edges.get(&edge_id)
208    }
209
210    /// Iterate over all nodes
211    pub fn nodes(&self) -> impl Iterator<Item = &GraphNode> {
212        self.nodes.values()
213    }
214
215    /// Iterate over all edges
216    pub fn edges(&self) -> impl Iterator<Item = &GraphEdge> {
217        self.edges.values()
218    }
219
220    /// Get a mutable reference to a node
221    pub fn get_node_mut(&mut self, node_id: NodeId) -> Option<&mut GraphNode> {
222        self.nodes.get_mut(&node_id)
223    }
224
225    /// Get a mutable reference to an edge
226    pub fn get_edge_mut(&mut self, edge_id: EdgeId) -> Option<&mut GraphEdge> {
227        self.edges.get_mut(&edge_id)
228    }
229
230    /// Get the number of nodes
231    pub fn node_count(&self) -> usize {
232        self.nodes.len()
233    }
234
235    /// Get the number of edges
236    pub fn edge_count(&self) -> usize {
237        self.edges.len()
238    }
239}
240
241impl Default for Graph {
242    fn default() -> Self {
243        Self::new()
244    }
245}