1use crate::{ASTOp, CompiledBackend, Compiler, AST, ASTUOp, ASTBOp, ASTROp};
2use alloc::{
3 collections::{btree_map::Entry, BTreeMap},
4 vec::Vec,
5};
6use zyx_core::axes::Axes;
7use zyx_core::dtype::DType;
8use zyx_core::error::ZyxError;
9use zyx_core::node::Node;
10use zyx_core::runtime::RuntimeBackend;
11use zyx_core::scalar::Scalar;
12use zyx_core::shape::Shape;
13use zyx_core::tensor::Id;
14use zyx_core::utils::get_dtype;
15use zyx_core::view::View;
16
17#[derive(Debug, Clone, PartialEq, Eq, Ord, PartialOrd)]
18pub(super) struct Kernel {
19 program_args: Vec<Id>,
20 arg_views: Vec<View>,
21 arg_dtypes: Vec<DType>,
22 ops: Vec<ASTOp>,
23 reduce_axes: Option<Axes>,
24 reduce_dtype: Option<DType>,
25 shape: Shape,
26 dtype: DType,
27 flop: usize,
28 bytes: usize,
29}
30
31impl Kernel {
32 fn leaf(x: Id, shape: &Shape, dtype: &DType) -> Self {
33 Self {
34 program_args: alloc::vec![x],
35 arg_views: alloc::vec![View::new(shape.clone())],
36 arg_dtypes: alloc::vec![dtype.clone()],
37 ops: alloc::vec![ASTOp::Leaf(0)],
38 reduce_axes: None,
39 reduce_dtype: None,
40 shape: shape.clone(),
41 dtype: *dtype,
42 flop: 0,
43 bytes: shape.numel() * dtype.byte_size(),
44 }
45 }
46}
47
48impl<C: Compiler> RuntimeBackend for CompiledBackend<C> {
49 fn is_evaluated(&self, x: Id) -> bool {
50 self.kernels.contains_key(&x)
51 }
52
53 fn is_free_id(&self, x: Id) -> bool {
54 !(self.buffers.contains_key(&x) || self.kernels.contains_key(&x))
55 }
56
57 fn remove(&mut self, x: Id) -> Result<(), ZyxError> {
58 if let Some(Kernel { program_args, .. }) = self.kernels.remove(&x) {
59 for p in program_args.iter().chain([&x]) {
60 if !self
61 .kernels
62 .values()
63 .any(|k| k.program_args.contains(&p))
64 {
65 if let Some(mut buffer) = self.buffers.remove(&p) {
66 self.compiler.drop_buffer(&mut buffer)?;
68 }
69 }
70 }
71 }
72 Ok(())
73 }
74
75 fn store<T: Scalar, IT>(&mut self, x: Id, iter: IT) -> Result<(), ZyxError>
76 where
77 IT: IntoIterator<Item = T>,
78 IT::IntoIter: ExactSizeIterator,
79 {
80 let iter = iter.into_iter();
82 self.kernels
83 .insert(x, Kernel::leaf(x, &iter.len().into(), &T::dtype()));
84 self.buffers.insert(x, self.compiler.store(iter)?);
85 Ok(())
86 }
87
88 fn load<T: Scalar>(&mut self, x: Id, numel: usize) -> Result<Vec<T>, ZyxError> {
89 if let Some(buffer) = self.buffers.get(&x) {
91 self.compiler.load(buffer, numel)
92 } else {
93 self.evaluate_kernel(x)?;
94 self.compiler.load(&self.buffers[&x], numel)
95 }
96 }
97
98 fn evaluate(
99 &mut self,
100 mut rcs: BTreeMap<Id, u32>,
101 order: &[Id],
102 nodes: &[Node],
103 ) -> Result<(), ZyxError> {
104 for nid in order.iter().copied() {
108 let mut kernel = match &nodes[nid.i()] {
110 Node::Leaf(sh, dtype) => Kernel::leaf(nid, sh, dtype),
111 Node::Uniform(..) => {
112 todo!()
113 }
114 Node::Cast(x, dtype) => {
115 let mut buffer = self.kernels[&x].clone();
116 buffer
117 .ops
118 .push(ASTOp::Unary(buffer.ops.len() as u8 - 1, ASTUOp::Cast(*dtype)));
119 buffer.dtype = *dtype;
120 buffer
121 }
122 Node::Detach(x) => self.kernels[&x].clone(),
123 Node::Neg(x) => {
124 let mut buffer = self.kernels[&x].clone();
125 buffer.ops.push(ASTOp::Unary(buffer.ops.len() as u8 - 1, ASTUOp::Neg));
126 buffer
127 }
128 Node::ReLU(x) => {
129 let mut buffer = self.kernels[&x].clone();
130 buffer.ops.push(ASTOp::Unary(buffer.ops.len() as u8 - 1, ASTUOp::ReLU));
131 buffer
132 }
133 Node::Exp(x) => {
134 let mut buffer = self.kernels[&x].clone();
135 buffer.ops.push(ASTOp::Unary(buffer.ops.len() as u8 - 1, ASTUOp::Exp));
136 buffer
137 }
138 Node::Ln(x) => {
139 let mut buffer = self.kernels[&x].clone();
140 buffer.ops.push(ASTOp::Unary(buffer.ops.len() as u8 - 1, ASTUOp::Ln));
141 buffer
142 }
143 Node::Sin(x) => {
144 let mut buffer = self.kernels[&x].clone();
145 buffer.ops.push(ASTOp::Unary(buffer.ops.len() as u8 - 1, ASTUOp::Sin));
146 buffer
147 }
148 Node::Cos(x) => {
149 let mut buffer = self.kernels[&x].clone();
150 buffer.ops.push(ASTOp::Unary(buffer.ops.len() as u8 - 1, ASTUOp::Cos));
151 buffer
152 }
153 Node::Sqrt(x) => {
154 let mut buffer = self.kernels[&x].clone();
155 buffer.ops.push(ASTOp::Unary(buffer.ops.len() as u8 - 1, ASTUOp::Sqrt));
156 buffer
157 }
158 Node::Tanh(x) => {
159 let mut kernel = self.kernels[&x].clone();
160 kernel.ops.push(ASTOp::Unary(kernel.ops.len() as u8 - 1, ASTUOp::Tanh));
161 kernel
162 }
163 Node::Add(x, y) => self.binary_kernel(*x, *y, |x, y| ASTOp::Binary(x, y, ASTBOp::Add))?,
164 Node::Sub(x, y) => self.binary_kernel(*x, *y, |x, y| ASTOp::Binary(x, y, ASTBOp::Sub))?,
165 Node::Mul(x, y) => self.binary_kernel(*x, *y, |x, y| ASTOp::Binary(x, y, ASTBOp::Mul))?,
166 Node::Div(x, y) => self.binary_kernel(*x, *y, |x, y| ASTOp::Binary(x, y, ASTBOp::Div))?,
167 Node::Pow(x, y) => self.binary_kernel(*x, *y, |x, y| ASTOp::Binary(x, y, ASTBOp::Pow))?,
168 Node::Cmplt(x, y) => self.binary_kernel(*x, *y, |x, y| ASTOp::Binary(x, y, ASTBOp::Cmplt))?,
169 Node::Where(..) => {
170 todo!()
172 }
173 Node::Reshape(x, sh) => {
174 let mut buffer = if self.kernels[&x].reduce_axes.is_some() {
175 self.evaluate_kernel(*x)?.clone()
181 } else {
182 self.kernels[&x].clone()
183 };
184 for view in &mut buffer.arg_views {
185 *view = view.reshape(sh);
186 }
187 buffer.shape = sh.clone();
188 buffer
189 }
190 Node::Expand(x, sh) => {
191 let mut kernel = if self.kernels[&x].reduce_axes.is_some() {
192 self.evaluate_kernel(*x)?.clone()
193 } else {
194 self.kernels[&x].clone()
195 };
196 for view in &mut kernel.arg_views {
197 *view = view.expand(sh);
198 }
199 kernel.shape = sh.clone();
200 kernel
201 }
202 Node::Permute(x, ax, sh) => {
203 let mut kernel = self.kernels[&x].clone();
204 for view in &mut kernel.arg_views {
205 *view = view.permute(ax);
206 }
207 if let Some(reduce_axes) = &mut kernel.reduce_axes {
208 *reduce_axes = reduce_axes.permute(ax);
209 }
210 kernel.shape = sh.clone();
211 kernel
212 }
213 Node::Pad(x, padding, sh) => {
214 let mut kernel = if self.kernels[&x].reduce_axes.is_some() {
215 self.evaluate_kernel(*x)?.clone()
216 } else {
217 self.kernels[&x].clone()
218 };
219 for view in &mut kernel.arg_views {
220 *view = view.pad(padding);
221 }
222 kernel.shape = sh.clone();
223 kernel
224 }
225 Node::Sum(x, ax, _) => {
226 let mut kernel = self.kernels[&x].clone();
227 if kernel.reduce_axes.is_some() {
228 kernel = self.evaluate_kernel(*x)?.clone();
229 kernel.reduce_axes = Some(ax.clone());
230 kernel.reduce_dtype = Some(get_dtype(nodes, nid));
231 kernel.ops.push(ASTOp::Reduce(0, ASTROp::Sum));
232 } else {
233 kernel.reduce_axes = Some(ax.clone());
234 kernel.reduce_dtype = Some(get_dtype(nodes, nid));
235 kernel.ops.push(ASTOp::Reduce(kernel.ops.len() as u8 - 1, ASTROp::Sum));
236 }
237 kernel
238 }
239 Node::Max(x, ax, _) => {
240 let mut kernel = self.kernels[&x].clone();
241 if kernel.reduce_axes.is_some() {
242 kernel = self.evaluate_kernel(*x)?.clone();
243 kernel.reduce_axes = Some(ax.clone());
244 kernel.reduce_dtype = Some(get_dtype(nodes, nid));
245 kernel.ops.push(ASTOp::Reduce(0, ASTROp::Max));
246 } else {
247 kernel.reduce_axes = Some(ax.clone());
248 kernel.reduce_dtype = Some(get_dtype(nodes, nid));
249 kernel.ops.push(ASTOp::Reduce(kernel.ops.len() as u8 - 1, ASTROp::Max));
250 }
251 kernel
252 }
253 };
254 kernel.flop += nodes[nid.i()].flop(&nodes);
255 self.kernels.insert(nid, kernel);
257
258 if self.kernels[&nid].ops.len() > 200
259 || (rcs[&nid] > 1 && self.kernels[&nid].program_args.len() > 1)
260 {
261 self.evaluate_kernel(nid)?;
263 }
264 for p in nodes[nid.i()].parameters() {
267 if let Entry::Occupied(e) = rcs.entry(p).and_modify(|rc| *rc -= 1) {
268 if *e.get() == 0 {
269 self.remove(p)?;
270 }
271 }
272 }
273 }
274 Ok(())
275 }
276}
277
278impl<C: Compiler> CompiledBackend<C> {
279 pub fn new(compiler: C) -> Self {
281 Self {
282 compiler,
283 kernels: BTreeMap::new(),
284 buffers: BTreeMap::new(),
285 programs: BTreeMap::new(),
286 }
287 }
288
289 fn evaluate_kernel(&mut self, x: Id) -> Result<&Kernel, ZyxError> {
290 if self.buffers.contains_key(&x) {
292 return Ok(&self.kernels[&x]);
294 }
295 let Kernel {
296 program_args,
297 arg_views,
298 arg_dtypes,
299 ops,
300 reduce_axes,
301 reduce_dtype,
302 shape,
303 dtype,
304 flop,
305 bytes,
306 } = self.kernels[&x].clone();
307 let r_shape = if let Some(reduce_axes) = &reduce_axes {
308 shape.clone().reduce(reduce_axes)
309 } else {
310 shape.clone()
311 };
312 let ast = AST {
313 arg_views,
314 arg_dtypes,
315 ops,
316 shape,
317 dtype,
318 reduce_axes,
319 reduce_dtype,
320 };
321 let program = if let Some(program) = self.programs.get(&ast) {
325 program
326 } else {
327 let ir = crate::ir::ast_to_ir(&ast, 256, 256*1024*8, 64);
330 let program = self.compiler.compile(&ir)?;
331 self.programs.entry(ast).or_insert(program)
332 };
333 let program_args: Vec<&C::Buffer> = program_args
334 .into_iter()
335 .map(|nid| &self.buffers[&nid])
336 .collect();
337 self.buffers.insert(
339 x,
340 self.compiler.launch(program, &program_args, flop, bytes)?,
341 );
342
343 if let Some(kernel) = self.kernels.insert(x, Kernel::leaf(x, &r_shape, &dtype)) {
345 for p in kernel.program_args {
346 if !self
347 .kernels
348 .values()
349 .any(|k| k.program_args.contains(&p))
350 {
351 if let Some(mut buffer) = self.buffers.remove(&p) {
352 self.compiler.drop_buffer(&mut buffer)?;
354 }
355 }
356 }
357 }
358 Ok(&self.kernels[&x])
359 }
360
361 fn binary_kernel(
362 &mut self,
363 x: Id,
364 y: Id,
365 op: impl Fn(u8, u8) -> ASTOp,
366 ) -> Result<Kernel, ZyxError> {
367 let (reduce_axes, reduce_dtype) = if x != y {
368 match (
369 self.kernels[&x].reduce_axes.clone(),
370 self.kernels[&y].reduce_axes.clone(),
371 ) {
372 (Some(x_ax), Some(_)) => {
373 self.evaluate_kernel(y)?;
374 (Some(x_ax), Some(self.kernels[&x].dtype))
375 }
376 (Some(x_ax), None) => (Some(x_ax), Some(self.kernels[&x].dtype)),
377 (None, Some(y_ax)) => (Some(y_ax), Some(self.kernels[&y].dtype)),
378 (None, None) => (None, None),
379 }
380 } else {
381 let mut buffer = if self.kernels[&x].reduce_axes.is_some() {
384 self.evaluate_kernel(x)?.clone()
385 } else {
386 self.kernels[&x].clone()
387 };
388 let n = buffer.ops.len() as u8 - 1;
389 buffer.ops.push(op(n, n));
390 return Ok(buffer);
392 };
393 let x_buffer = &self.kernels[&x];
394 let y_buffer = &self.kernels[&y];
395 let n = x_buffer.ops.len() as u8;
396 Ok(Kernel {
397 program_args: x_buffer
398 .program_args
399 .iter()
400 .chain(y_buffer.program_args.iter())
401 .copied()
402 .collect(),
403 arg_views: x_buffer
404 .arg_views
405 .iter()
406 .chain(y_buffer.arg_views.iter())
407 .cloned()
408 .collect(),
409 arg_dtypes: x_buffer
410 .arg_dtypes
411 .iter()
412 .chain(y_buffer.arg_dtypes.iter())
413 .copied()
414 .collect(),
415 ops: x_buffer
416 .ops
417 .iter()
418 .cloned()
419 .chain(y_buffer.ops.iter().cloned().map(|mut op| {
420 match &mut op {
421 ASTOp::Leaf(x) => *x += x_buffer.arg_views.len() as u8,
422 ASTOp::Unary(x, ..) | ASTOp::Reduce(x, ..) => *x += n,
423 ASTOp::Binary(x, y, ..) => {
424 *x += n;
425 *y += n;
426 }
427 ASTOp::Where(x, y, z) => {
428 *x += n;
429 *y += n;
430 *z += n;
431 }
432 }
433 op
434 }))
435 .chain([op(n - 1, n + y_buffer.ops.len() as u8 - 1)])
436 .collect(),
437 reduce_axes,
438 reduce_dtype,
439 shape: x_buffer.shape.clone(),
440 dtype: x_buffer.dtype,
441 flop: x_buffer.flop + y_buffer.flop,
442 bytes: x_buffer.bytes + y_buffer.bytes,
443 })
444 }
445}