tensor_forge/node.rs
//! Representation of a single operation instance (a node) in an ML graph.
2
3use crate::op::OpKind;
4
/// Opaque, copyable identifier naming one node of a graph.
///
/// Equality, ordering and hashing all follow the wrapped integer value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct NodeId(usize);
8
9/// Struct for a graph node in the ML operations sequence.
10///
11/// They are constructed
12/// automatically by interacting with operations in a [`crate::graph::Graph`] struct.
13///
14/// Each node represents a particular action as determined by the
15/// [`OpKind`] field.
16///
17/// # Examples
18/// #TODO
19#[derive(Debug, PartialEq, Eq, Hash)]
20pub struct Node {
21 /// ID of the current node.
22 ///
23 /// This is automatically generated on Node creation.
24 pub id: NodeId,
25 /// Operation of the current node in the ML pipeline. See [`OpKind`].
26 pub op: OpKind,
27 /// Node IDs of the inputs to this operation.
28 pub inputs: Vec<NodeId>,
29 /// Tensor dimensions (shape) of the output tensor produced by this node.
30 pub shape: Vec<usize>,
31}
32
33use std::sync::atomic::{AtomicU32, Ordering};
34
35static ID_COUNTER: AtomicU32 = AtomicU32::new(0);
36
37impl Node {
38 /// Rather than use reference-counted pointers, `graph` contains
39 /// a list of all the valid nodes. Input-output pairs are generated
40 /// by examining the indices and forming a DAG. See [`graph.rs`] for more information.
41 pub(crate) fn new(op: OpKind, inputs: Vec<NodeId>, shape: Vec<usize>) -> Self {
42 let node_id = NodeId(ID_COUNTER.fetch_add(1, Ordering::SeqCst) as usize);
43 Self {
44 id: node_id,
45 op,
46 inputs,
47 shape,
48 }
49 }
50}