zyx_core/
node.rs

1extern crate alloc;
2use crate::dtype::DType;
3use crate::utils::get_shape;
4use crate::{axes::Axes, shape::Shape, tensor::Id};
5use alloc::boxed::Box;
6use core::fmt::Formatter;
7
8/// Node representing different possible tensors
9pub enum Node {
10    /// Detach tensor from tape
11    Detach(Id),
12    /// Leaf that is guaranteed to be evaluated
13    Leaf(Shape, DType),
14    /// Uniform initializer for range 0..1
15    Uniform(Shape, DType),
16    /// Cast to dtype unary op
17    Cast(Id, DType),
18    /// Neg unary op
19    Neg(Id),
20    /// ReLU unary op
21    ReLU(Id),
22    /// Sine unary op
23    Sin(Id),
24    /// Cosine unary op
25    Cos(Id),
26    /// Natural logarithm unary op
27    Ln(Id),
28    /// Exp unary op
29    Exp(Id),
30    /// Hyperbolic tangent unary op
31    Tanh(Id),
32    /// Square root unary op
33    Sqrt(Id),
34    /// Addition binary op
35    Add(Id, Id),
36    /// Subtraction binary op
37    Sub(Id, Id),
38    /// Multiplication binary op
39    Mul(Id, Id),
40    /// Division binary op
41    Div(Id, Id),
42    /// Exponentiation binary op
43    Pow(Id, Id),
44    /// Compare less than binary op
45    Cmplt(Id, Id),
46    /// Where op
47    Where(Id, Id, Id),
48    /// Reshape movement op
49    Reshape(Id, Shape),
50    /// Expand movement op
51    Expand(Id, Shape),
52    /// Permute movement op
53    Permute(Id, Axes, Shape),
54    /// Pad movement op
55    Pad(Id, Box<[(i64, i64)]>, Shape),
56    /// Sum reduce op
57    Sum(Id, Axes, Shape),
58    /// Max reduce op
59    Max(Id, Axes, Shape),
60}
61
62impl core::fmt::Debug for Node {
63    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
64        match self {
65            Node::Detach(x) => f.write_fmt(format_args!("Detach({x})")),
66            Node::Leaf(sh, dtype) => f.write_fmt(format_args!("Leaf({sh}, {dtype})")),
67            Node::Cast(x, dtype) => f.write_fmt(format_args!("Cast({x}, {dtype})")),
68            Node::Uniform(sh, dtype) => f.write_fmt(format_args!("Uniform({sh}, {dtype})")),
69            Node::Neg(x) => f.write_fmt(format_args!("Neg({x})")),
70            Node::ReLU(x) => f.write_fmt(format_args!("ReLU({x})")),
71            Node::Sin(x) => f.write_fmt(format_args!("Sin({x})")),
72            Node::Cos(x) => f.write_fmt(format_args!("Cos({x})")),
73            Node::Ln(x) => f.write_fmt(format_args!("Ln({x})")),
74            Node::Exp(x) => f.write_fmt(format_args!("Exp({x})")),
75            Node::Tanh(x) => f.write_fmt(format_args!("Tanh({x})")),
76            Node::Sqrt(x) => f.write_fmt(format_args!("Sqrt({x})")),
77            Node::Add(x, y) => f.write_fmt(format_args!("Add({x}, {y})")),
78            Node::Sub(x, y) => f.write_fmt(format_args!("Sub({x}, {y})")),
79            Node::Mul(x, y) => f.write_fmt(format_args!("Mul({x}, {y})")),
80            Node::Div(x, y) => f.write_fmt(format_args!("Div({x}, {y})")),
81            Node::Pow(x, y) => f.write_fmt(format_args!("Pow({x}, {y})")),
82            Node::Cmplt(x, y) => f.write_fmt(format_args!("Cmplt({x}, {y})")),
83            Node::Where(x, y, z) => f.write_fmt(format_args!("Where({x}, {y}, {z})")),
84            Node::Expand(x, sh) => f.write_fmt(format_args!("Expand({x}, {sh})")),
85            Node::Reshape(x, sh) => f.write_fmt(format_args!("Reshape({x}, {sh})")),
86            Node::Pad(x, padding, ..) => f.write_fmt(format_args!("Pad({x}, {padding:?})")),
87            Node::Permute(x, ax, ..) => f.write_fmt(format_args!("Permute({x}, {ax})")),
88            Node::Sum(x, ax, sh) => f.write_fmt(format_args!("Sum({x}, {ax}, {sh})")),
89            Node::Max(x, ax, sh) => f.write_fmt(format_args!("Max({x}, {ax}, {sh})")),
90        }
91    }
92}
93
94/// Iterator over parameters of node which does not allocate on heap.
95pub struct NodeParametersIterator {
96    parameters: [Id; 3],
97    len: u8,
98    idx: u8,
99}
100
101impl Iterator for NodeParametersIterator {
102    type Item = Id;
103    fn next(&mut self) -> Option<Self::Item> {
104        if self.idx == self.len {
105            return None;
106        }
107        let idx = self.idx;
108        self.idx += 1;
109        Some(self.parameters[idx as usize])
110    }
111}
112
113impl Node {
114    /// Get number of parameters of self. This method does not allocate.
115    pub const fn num_parameters(&self) -> u8 {
116        match self {
117            Node::Leaf(..) | Node::Uniform(..) => 0,
118            Node::Detach(..)
119            | Node::Cast(..)
120            | Node::Neg(..)
121            | Node::ReLU(..)
122            | Node::Exp(..)
123            | Node::Ln(..)
124            | Node::Sin(..)
125            | Node::Cos(..)
126            | Node::Sqrt(..)
127            | Node::Tanh(..)
128            | Node::Reshape(..)
129            | Node::Expand(..)
130            | Node::Permute(..)
131            | Node::Pad(..)
132            | Node::Sum(..)
133            | Node::Max(..) => 1,
134            Node::Add(..)
135            | Node::Sub(..)
136            | Node::Mul(..)
137            | Node::Div(..)
138            | Node::Cmplt(..)
139            | Node::Pow(..) => 2,
140            Node::Where(..) => 3,
141        }
142    }
143
144    /// Get all parameters of self. This method does not allocate.
145    pub const fn parameters(&self) -> impl Iterator<Item = Id> {
146        match self {
147            Node::Leaf(..) | Node::Uniform(..) => NodeParametersIterator {
148                parameters: [crate::tensor::id(0); 3],
149                idx: 0,
150                len: 0,
151            },
152            Node::Cast(x, ..)
153            | Node::Detach(x)
154            | Node::Neg(x)
155            | Node::ReLU(x)
156            | Node::Exp(x)
157            | Node::Ln(x)
158            | Node::Sin(x)
159            | Node::Cos(x)
160            | Node::Sqrt(x)
161            | Node::Tanh(x)
162            | Node::Reshape(x, ..)
163            | Node::Expand(x, ..)
164            | Node::Permute(x, ..)
165            | Node::Pad(x, ..)
166            | Node::Sum(x, ..)
167            | Node::Max(x, ..) => NodeParametersIterator {
168                parameters: [*x, crate::tensor::id(0), crate::tensor::id(0)],
169                idx: 0,
170                len: 1,
171            },
172            Node::Add(x, y)
173            | Node::Sub(x, y)
174            | Node::Mul(x, y)
175            | Node::Div(x, y)
176            | Node::Cmplt(x, y)
177            | Node::Pow(x, y) => NodeParametersIterator {
178                parameters: [*x, *y, crate::tensor::id(0)],
179                idx: 0,
180                len: 2,
181            },
182            Node::Where(x, y, z) => NodeParametersIterator {
183                parameters: [*x, *y, *z],
184                idx: 0,
185                len: 3,
186            },
187        }
188    }
189
190    /// Get number of operations necessary to calculate this node
191    pub fn flop(&self, nodes: &[Node]) -> usize {
192        match self {
193            Node::Detach(..)
194            | Node::Leaf(..)
195            | Node::Uniform(..)
196            | Node::Reshape(..)
197            | Node::Expand(..)
198            | Node::Permute(..)
199            | Node::Pad(..) => 0,
200            Node::Where(x, ..)
201            | Node::Add(x, _)
202            | Node::Sub(x, _)
203            | Node::Mul(x, _)
204            | Node::Div(x, _)
205            | Node::Cmplt(x, _)
206            | Node::Pow(x, _) => get_shape(nodes, *x).numel(), // x and y are guaranteed to be same shape
207            Node::Cast(x, ..)
208            | Node::Neg(x)
209            | Node::ReLU(x)
210            | Node::Exp(x)
211            | Node::Ln(x)
212            | Node::Sin(x)
213            | Node::Cos(x)
214            | Node::Sqrt(x)
215            | Node::Tanh(x) => get_shape(nodes, *x).numel(),
216            Node::Sum(x, _, sh) | Node::Max(x, _, sh) => {
217                let n = sh.numel();
218                let rdim = get_shape(nodes, *x).numel() / n;
219                rdim * n // technically (rdim-1)*n, but hardware needs to do rdim*n
220            }
221        }
222    }
223
224    /// Check if parameters of self contains nid.
225    pub fn parameters_contain(&self, nid: Id) -> bool {
226        match self {
227            Node::Leaf(..) | Node::Uniform(..) => false,
228            Node::Detach(x)
229            | Node::Cast(x, ..)
230            | Node::Neg(x)
231            | Node::ReLU(x)
232            | Node::Exp(x)
233            | Node::Ln(x)
234            | Node::Sin(x)
235            | Node::Cos(x)
236            | Node::Sqrt(x)
237            | Node::Tanh(x)
238            | Node::Sum(x, ..)
239            | Node::Max(x, ..)
240            | Node::Reshape(x, ..)
241            | Node::Expand(x, ..)
242            | Node::Permute(x, ..)
243            | Node::Pad(x, ..) => nid == *x,
244            Node::Add(x, y)
245            | Node::Sub(x, y)
246            | Node::Mul(x, y)
247            | Node::Div(x, y)
248            | Node::Cmplt(x, y)
249            | Node::Pow(x, y) => nid == *x || nid == *y,
250            Node::Where(x, y, z) => nid == *x || nid == *y || nid == *z,
251        }
252    }
253
254    /// Is this reduce node? (sum or max)
255    pub fn is_reduce(&self) -> bool {
256        matches!(self, Node::Sum(..) | Node::Max(..))
257    }
258}