1extern crate alloc;
2use crate::dtype::DType;
3use crate::utils::get_shape;
4use crate::{axes::Axes, shape::Shape, tensor::Id};
5use alloc::boxed::Box;
6use core::fmt::Formatter;
7
/// One operation in the tensor graph.
///
/// Each variant stores the [`Id`]s of the nodes it consumes (iterate them with
/// [`Node::parameters`]), followed by any static metadata the operation needs:
/// a target [`Shape`], a [`DType`], reduction/permutation [`Axes`], or padding.
pub enum Node {
    /// Single-input node that [`Node::flop`] treats as free. Presumably cuts
    /// the input out of gradient tracking — NOTE(review): confirm.
    Detach(Id),
    /// Parameter-free source node, described only by its shape and dtype.
    Leaf(Shape, DType),
    /// Parameter-free node of the given shape and dtype; presumably filled
    /// with uniformly distributed random values — TODO confirm.
    Uniform(Shape, DType),
    /// Casts the input to the given dtype (one flop per input element).
    Cast(Id, DType),
    /// Unary negation (one flop per input element).
    Neg(Id),
    /// Unary rectified linear unit (one flop per input element).
    ReLU(Id),
    /// Unary sine (one flop per input element).
    Sin(Id),
    /// Unary cosine (one flop per input element).
    Cos(Id),
    /// Unary natural logarithm (one flop per input element).
    Ln(Id),
    /// Unary exponential (one flop per input element).
    Exp(Id),
    /// Unary hyperbolic tangent (one flop per input element).
    Tanh(Id),
    /// Unary square root (one flop per input element).
    Sqrt(Id),
    /// Binary addition of the two inputs.
    Add(Id, Id),
    /// Binary subtraction of the two inputs.
    Sub(Id, Id),
    /// Binary multiplication of the two inputs.
    Mul(Id, Id),
    /// Binary division of the two inputs.
    Div(Id, Id),
    /// Binary power: the first input raised to the second.
    Pow(Id, Id),
    /// Binary less-than comparison, presumably `first < second` — TODO
    /// confirm the output dtype.
    Cmplt(Id, Id),
    /// Three-input select; presumably (condition, if-true, if-false) —
    /// TODO confirm operand order.
    Where(Id, Id, Id),
    /// Movement op giving the input the stored shape; free in [`Node::flop`].
    Reshape(Id, Shape),
    /// Movement op, presumably broadcasting the input to the stored shape;
    /// free in [`Node::flop`].
    Expand(Id, Shape),
    /// Movement op reordering axes according to `Axes`; the `Shape` is
    /// presumably the resulting shape — TODO confirm. Free in [`Node::flop`].
    Permute(Id, Axes, Shape),
    /// Movement op with one `(i64, i64)` padding pair per padded axis
    /// (presumably `(before, after)`; signed, so possibly negative to crop —
    /// TODO confirm). The `Shape` is presumably the result shape.
    Pad(Id, Box<[(i64, i64)]>, Shape),
    /// Sum reduction over `Axes`. The `Shape` appears to be the reduced
    /// (output) shape: [`Node::flop`] divides the input's element count by it.
    Sum(Id, Axes, Shape),
    /// Max reduction over `Axes`; the `Shape` is used like in [`Node::Sum`].
    Max(Id, Axes, Shape),
}
61
62impl core::fmt::Debug for Node {
63 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
64 match self {
65 Node::Detach(x) => f.write_fmt(format_args!("Detach({x})")),
66 Node::Leaf(sh, dtype) => f.write_fmt(format_args!("Leaf({sh}, {dtype})")),
67 Node::Cast(x, dtype) => f.write_fmt(format_args!("Cast({x}, {dtype})")),
68 Node::Uniform(sh, dtype) => f.write_fmt(format_args!("Uniform({sh}, {dtype})")),
69 Node::Neg(x) => f.write_fmt(format_args!("Neg({x})")),
70 Node::ReLU(x) => f.write_fmt(format_args!("ReLU({x})")),
71 Node::Sin(x) => f.write_fmt(format_args!("Sin({x})")),
72 Node::Cos(x) => f.write_fmt(format_args!("Cos({x})")),
73 Node::Ln(x) => f.write_fmt(format_args!("Ln({x})")),
74 Node::Exp(x) => f.write_fmt(format_args!("Exp({x})")),
75 Node::Tanh(x) => f.write_fmt(format_args!("Tanh({x})")),
76 Node::Sqrt(x) => f.write_fmt(format_args!("Sqrt({x})")),
77 Node::Add(x, y) => f.write_fmt(format_args!("Add({x}, {y})")),
78 Node::Sub(x, y) => f.write_fmt(format_args!("Sub({x}, {y})")),
79 Node::Mul(x, y) => f.write_fmt(format_args!("Mul({x}, {y})")),
80 Node::Div(x, y) => f.write_fmt(format_args!("Div({x}, {y})")),
81 Node::Pow(x, y) => f.write_fmt(format_args!("Pow({x}, {y})")),
82 Node::Cmplt(x, y) => f.write_fmt(format_args!("Cmplt({x}, {y})")),
83 Node::Where(x, y, z) => f.write_fmt(format_args!("Where({x}, {y}, {z})")),
84 Node::Expand(x, sh) => f.write_fmt(format_args!("Expand({x}, {sh})")),
85 Node::Reshape(x, sh) => f.write_fmt(format_args!("Reshape({x}, {sh})")),
86 Node::Pad(x, padding, ..) => f.write_fmt(format_args!("Pad({x}, {padding:?})")),
87 Node::Permute(x, ax, ..) => f.write_fmt(format_args!("Permute({x}, {ax})")),
88 Node::Sum(x, ax, sh) => f.write_fmt(format_args!("Sum({x}, {ax}, {sh})")),
89 Node::Max(x, ax, sh) => f.write_fmt(format_args!("Max({x}, {ax}, {sh})")),
90 }
91 }
92}
93
94pub struct NodeParametersIterator {
96 parameters: [Id; 3],
97 len: u8,
98 idx: u8,
99}
100
101impl Iterator for NodeParametersIterator {
102 type Item = Id;
103 fn next(&mut self) -> Option<Self::Item> {
104 if self.idx == self.len {
105 return None;
106 }
107 let idx = self.idx;
108 self.idx += 1;
109 Some(self.parameters[idx as usize])
110 }
111}
112
113impl Node {
114 pub const fn num_parameters(&self) -> u8 {
116 match self {
117 Node::Leaf(..) | Node::Uniform(..) => 0,
118 Node::Detach(..)
119 | Node::Cast(..)
120 | Node::Neg(..)
121 | Node::ReLU(..)
122 | Node::Exp(..)
123 | Node::Ln(..)
124 | Node::Sin(..)
125 | Node::Cos(..)
126 | Node::Sqrt(..)
127 | Node::Tanh(..)
128 | Node::Reshape(..)
129 | Node::Expand(..)
130 | Node::Permute(..)
131 | Node::Pad(..)
132 | Node::Sum(..)
133 | Node::Max(..) => 1,
134 Node::Add(..)
135 | Node::Sub(..)
136 | Node::Mul(..)
137 | Node::Div(..)
138 | Node::Cmplt(..)
139 | Node::Pow(..) => 2,
140 Node::Where(..) => 3,
141 }
142 }
143
144 pub const fn parameters(&self) -> impl Iterator<Item = Id> {
146 match self {
147 Node::Leaf(..) | Node::Uniform(..) => NodeParametersIterator {
148 parameters: [crate::tensor::id(0); 3],
149 idx: 0,
150 len: 0,
151 },
152 Node::Cast(x, ..)
153 | Node::Detach(x)
154 | Node::Neg(x)
155 | Node::ReLU(x)
156 | Node::Exp(x)
157 | Node::Ln(x)
158 | Node::Sin(x)
159 | Node::Cos(x)
160 | Node::Sqrt(x)
161 | Node::Tanh(x)
162 | Node::Reshape(x, ..)
163 | Node::Expand(x, ..)
164 | Node::Permute(x, ..)
165 | Node::Pad(x, ..)
166 | Node::Sum(x, ..)
167 | Node::Max(x, ..) => NodeParametersIterator {
168 parameters: [*x, crate::tensor::id(0), crate::tensor::id(0)],
169 idx: 0,
170 len: 1,
171 },
172 Node::Add(x, y)
173 | Node::Sub(x, y)
174 | Node::Mul(x, y)
175 | Node::Div(x, y)
176 | Node::Cmplt(x, y)
177 | Node::Pow(x, y) => NodeParametersIterator {
178 parameters: [*x, *y, crate::tensor::id(0)],
179 idx: 0,
180 len: 2,
181 },
182 Node::Where(x, y, z) => NodeParametersIterator {
183 parameters: [*x, *y, *z],
184 idx: 0,
185 len: 3,
186 },
187 }
188 }
189
190 pub fn flop(&self, nodes: &[Node]) -> usize {
192 match self {
193 Node::Detach(..)
194 | Node::Leaf(..)
195 | Node::Uniform(..)
196 | Node::Reshape(..)
197 | Node::Expand(..)
198 | Node::Permute(..)
199 | Node::Pad(..) => 0,
200 Node::Where(x, ..)
201 | Node::Add(x, _)
202 | Node::Sub(x, _)
203 | Node::Mul(x, _)
204 | Node::Div(x, _)
205 | Node::Cmplt(x, _)
206 | Node::Pow(x, _) => get_shape(nodes, *x).numel(), Node::Cast(x, ..)
208 | Node::Neg(x)
209 | Node::ReLU(x)
210 | Node::Exp(x)
211 | Node::Ln(x)
212 | Node::Sin(x)
213 | Node::Cos(x)
214 | Node::Sqrt(x)
215 | Node::Tanh(x) => get_shape(nodes, *x).numel(),
216 Node::Sum(x, _, sh) | Node::Max(x, _, sh) => {
217 let n = sh.numel();
218 let rdim = get_shape(nodes, *x).numel() / n;
219 rdim * n }
221 }
222 }
223
224 pub fn parameters_contain(&self, nid: Id) -> bool {
226 match self {
227 Node::Leaf(..) | Node::Uniform(..) => false,
228 Node::Detach(x)
229 | Node::Cast(x, ..)
230 | Node::Neg(x)
231 | Node::ReLU(x)
232 | Node::Exp(x)
233 | Node::Ln(x)
234 | Node::Sin(x)
235 | Node::Cos(x)
236 | Node::Sqrt(x)
237 | Node::Tanh(x)
238 | Node::Sum(x, ..)
239 | Node::Max(x, ..)
240 | Node::Reshape(x, ..)
241 | Node::Expand(x, ..)
242 | Node::Permute(x, ..)
243 | Node::Pad(x, ..) => nid == *x,
244 Node::Add(x, y)
245 | Node::Sub(x, y)
246 | Node::Mul(x, y)
247 | Node::Div(x, y)
248 | Node::Cmplt(x, y)
249 | Node::Pow(x, y) => nid == *x || nid == *y,
250 Node::Where(x, y, z) => nid == *x || nid == *y || nid == *z,
251 }
252 }
253
254 pub fn is_reduce(&self) -> bool {
256 matches!(self, Node::Sum(..) | Node::Max(..))
257 }
258}