Skip to main content

tensor_forge/
node.rs

//! Representation of one operation instance (a node) in an ML graph.
2
3use crate::op::OpKind;
4
/// Opaque identifier of a single node within a graph.
///
/// The wrapped index is private to the defining module, so code elsewhere can copy,
/// compare, order, and hash `NodeId`s (for example to use them as map keys), but cannot
/// construct new ones from a raw integer.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct NodeId(usize);
8
9/// Struct for a graph node in the ML operations sequence.
10///
11/// They are constructed
12/// automatically by interacting with operations in a [`crate::graph::Graph`] struct.
13///
14/// Each node represents a particular action as determined by the
15/// [`OpKind`] field.
16///
17/// # Examples
18/// #TODO
19#[derive(Debug, PartialEq, Eq, Hash)]
20pub struct Node {
21    /// ID of the current node.
22    ///
23    /// This is automatically generated on Node creation.
24    pub id: NodeId,
25    /// Operation of the current node in the ML pipeline. See [`OpKind`].
26    pub op: OpKind,
27    /// Node IDs of the inputs to this operation.
28    pub inputs: Vec<NodeId>,
29    /// Tensor dimensions (shape) of the output tensor produced by this node.
30    pub shape: Vec<usize>,
31}
32
33use std::sync::atomic::{AtomicU32, Ordering};
34
35static ID_COUNTER: AtomicU32 = AtomicU32::new(0);
36
37impl Node {
38    /// Rather than use reference-counted pointers, `graph` contains
39    /// a list of all the valid nodes. Input-output pairs are generated
40    /// by examining the indices and forming a DAG. See [`graph.rs`] for more information.
41    pub(crate) fn new(op: OpKind, inputs: Vec<NodeId>, shape: Vec<usize>) -> Self {
42        let node_id = NodeId(ID_COUNTER.fetch_add(1, Ordering::SeqCst) as usize);
43        Self {
44            id: node_id,
45            op,
46            inputs,
47            shape,
48        }
49    }
50}