// zyx_compiler/ast.rs

1use crate::{ASTOp, CompiledBackend, Compiler, AST, ASTUOp, ASTBOp, ASTROp};
2use alloc::{
3    collections::{btree_map::Entry, BTreeMap},
4    vec::Vec,
5};
6use zyx_core::axes::Axes;
7use zyx_core::dtype::DType;
8use zyx_core::error::ZyxError;
9use zyx_core::node::Node;
10use zyx_core::runtime::RuntimeBackend;
11use zyx_core::scalar::Scalar;
12use zyx_core::shape::Shape;
13use zyx_core::tensor::Id;
14use zyx_core::utils::get_dtype;
15use zyx_core::view::View;
16
/// A lazily-built, possibly fused computation that produces one tensor.
///
/// Kernels are accumulated per node id in `CompiledBackend::kernels`; unary,
/// binary and movement ops are folded into an existing kernel until it is
/// forced to evaluate. NOTE: derived `Ord`/`PartialOrd` depend on field
/// order — do not reorder fields.
#[derive(Debug, Clone, PartialEq, Eq, Ord, PartialOrd)]
pub(super) struct Kernel {
    // Ids of the stored buffers this kernel reads as inputs.
    program_args: Vec<Id>,
    // One view per argument; movement ops (reshape/expand/permute/pad) are
    // applied to these views instead of materializing new buffers.
    arg_views: Vec<View>,
    // dtype of each argument buffer, parallel to `arg_views`.
    arg_dtypes: Vec<DType>,
    // Flattened AST; ops reference earlier ops by index (u8).
    ops: Vec<ASTOp>,
    // Axes reduced by this kernel, if it contains a reduce op.
    // A kernel holds at most one reduce; a second reduce forces evaluation.
    reduce_axes: Option<Axes>,
    // dtype at the reduce, if any.
    reduce_dtype: Option<DType>,
    // Shape of the kernel's work domain; while `reduce_axes` is `Some`, this
    // is the pre-reduce shape (the result shape is `shape.reduce(axes)`).
    shape: Shape,
    // dtype of the kernel's result.
    dtype: DType,
    // Accumulated floating-point-op estimate; forwarded to `Compiler::launch`.
    flop: usize,
    // Accumulated bytes-accessed estimate; forwarded to `Compiler::launch`.
    bytes: usize,
}
30
31impl Kernel {
32    fn leaf(x: Id, shape: &Shape, dtype: &DType) -> Self {
33        Self {
34            program_args: alloc::vec![x],
35            arg_views: alloc::vec![View::new(shape.clone())],
36            arg_dtypes: alloc::vec![dtype.clone()],
37            ops: alloc::vec![ASTOp::Leaf(0)],
38            reduce_axes: None,
39            reduce_dtype: None,
40            shape: shape.clone(),
41            dtype: *dtype,
42            flop: 0,
43            bytes: shape.numel() * dtype.byte_size(),
44        }
45    }
46}
47
impl<C: Compiler> RuntimeBackend for CompiledBackend<C> {
    /// A node counts as evaluated once a kernel entry exists for it.
    fn is_evaluated(&self, x: Id) -> bool {
        self.kernels.contains_key(&x)
    }

    /// An id is free when neither a buffer nor a kernel is registered under it.
    fn is_free_id(&self, x: Id) -> bool {
        !(self.buffers.contains_key(&x) || self.kernels.contains_key(&x))
    }

    /// Removes the kernel stored under `x`, then drops every buffer
    /// (each of the kernel's args, and `x`'s own buffer) that is no longer
    /// referenced by any remaining kernel.
    fn remove(&mut self, x: Id) -> Result<(), ZyxError> {
        if let Some(Kernel { program_args, .. }) = self.kernels.remove(&x) {
            // Chain `x` itself: its backing buffer may exist even after its
            // kernel entry is gone.
            for p in program_args.iter().chain([&x]) {
                // NOTE(review): `&p` here is `&&Id` while `program_args` holds
                // `Id` — verify this compiles against zyx_core's Id impls.
                if !self
                    .kernels
                    .values()
                    .any(|k| k.program_args.contains(&p))
                {
                    if let Some(mut buffer) = self.buffers.remove(&p) {
                        //std::println!("Dropping buffer {p} out of total {} buffers", self.buffers.len());
                        self.compiler.drop_buffer(&mut buffer)?;
                    }
                }
            }
        }
        Ok(())
    }

    /// Stores host data as a 1D buffer under id `x`: registers a leaf kernel
    /// (shape = element count) and uploads the data via the compiler.
    fn store<T: Scalar, IT>(&mut self, x: Id, iter: IT) -> Result<(), ZyxError>
    where
        IT: IntoIterator<Item = T>,
        IT::IntoIter: ExactSizeIterator,
    {
        //std::println!("Storing {x}");
        let iter = iter.into_iter();
        self.kernels
            .insert(x, Kernel::leaf(x, &iter.len().into(), &T::dtype()));
        self.buffers.insert(x, self.compiler.store(iter)?);
        Ok(())
    }

    /// Reads `numel` elements of node `x` back to the host, evaluating its
    /// kernel first if no buffer exists yet.
    fn load<T: Scalar>(&mut self, x: Id, numel: usize) -> Result<Vec<T>, ZyxError> {
        //std::println!("Loading {x}");
        if let Some(buffer) = self.buffers.get(&x) {
            self.compiler.load(buffer, numel)
        } else {
            self.evaluate_kernel(x)?;
            self.compiler.load(&self.buffers[&x], numel)
        }
    }

    /// Walks `order` (a topological order over `nodes`), building one fused
    /// kernel per node by extending the parents' kernels. Movement ops are
    /// folded into arg views; a kernel carries at most one reduce, so a
    /// second reduce (or a movement op after a reduce) forces evaluation of
    /// the parent first. `rcs` holds remaining reference counts; a parent is
    /// removed as soon as its count reaches zero.
    fn evaluate(
        &mut self,
        mut rcs: BTreeMap<Id, u32>,
        order: &[Id],
        nodes: &[Node],
    ) -> Result<(), ZyxError> {
        //std::println!("Evaluating rcs {:?}", rcs);
        // TODO must_eval are currently new_leafs from runtime, but this may not
        // be the case later and then this won't work, so fix it.
        for nid in order.iter().copied() {
            //std::println!("Compiling {nid}: {:?} x {}", nodes[nid.i()], rcs[&nid]);
            let mut kernel = match &nodes[nid.i()] {
                Node::Leaf(sh, dtype) => Kernel::leaf(nid, sh, dtype),
                Node::Uniform(..) => {
                    todo!()
                }
                // Unary ops: clone the parent kernel and append one op whose
                // input index is the parent's last op (ops.len() - 1).
                Node::Cast(x, dtype) => {
                    let mut buffer = self.kernels[&x].clone();
                    buffer
                        .ops
                        .push(ASTOp::Unary(buffer.ops.len() as u8 - 1, ASTUOp::Cast(*dtype)));
                    buffer.dtype = *dtype;
                    buffer
                }
                // Detach is an identity at this level (gradient-only op).
                Node::Detach(x) => self.kernels[&x].clone(),
                Node::Neg(x) => {
                    let mut buffer = self.kernels[&x].clone();
                    buffer.ops.push(ASTOp::Unary(buffer.ops.len() as u8 - 1, ASTUOp::Neg));
                    buffer
                }
                Node::ReLU(x) => {
                    let mut buffer = self.kernels[&x].clone();
                    buffer.ops.push(ASTOp::Unary(buffer.ops.len() as u8 - 1, ASTUOp::ReLU));
                    buffer
                }
                Node::Exp(x) => {
                    let mut buffer = self.kernels[&x].clone();
                    buffer.ops.push(ASTOp::Unary(buffer.ops.len() as u8 - 1, ASTUOp::Exp));
                    buffer
                }
                Node::Ln(x) => {
                    let mut buffer = self.kernels[&x].clone();
                    buffer.ops.push(ASTOp::Unary(buffer.ops.len() as u8 - 1, ASTUOp::Ln));
                    buffer
                }
                Node::Sin(x) => {
                    let mut buffer = self.kernels[&x].clone();
                    buffer.ops.push(ASTOp::Unary(buffer.ops.len() as u8 - 1, ASTUOp::Sin));
                    buffer
                }
                Node::Cos(x) => {
                    let mut buffer = self.kernels[&x].clone();
                    buffer.ops.push(ASTOp::Unary(buffer.ops.len() as u8 - 1, ASTUOp::Cos));
                    buffer
                }
                Node::Sqrt(x) => {
                    let mut buffer = self.kernels[&x].clone();
                    buffer.ops.push(ASTOp::Unary(buffer.ops.len() as u8 - 1, ASTUOp::Sqrt));
                    buffer
                }
                Node::Tanh(x) => {
                    let mut kernel = self.kernels[&x].clone();
                    kernel.ops.push(ASTOp::Unary(kernel.ops.len() as u8 - 1, ASTUOp::Tanh));
                    kernel
                }
                // Binary ops: merge both parents' kernels (see binary_kernel).
                Node::Add(x, y) => self.binary_kernel(*x, *y, |x, y| ASTOp::Binary(x, y, ASTBOp::Add))?,
                Node::Sub(x, y) => self.binary_kernel(*x, *y, |x, y| ASTOp::Binary(x, y, ASTBOp::Sub))?,
                Node::Mul(x, y) => self.binary_kernel(*x, *y, |x, y| ASTOp::Binary(x, y, ASTBOp::Mul))?,
                Node::Div(x, y) => self.binary_kernel(*x, *y, |x, y| ASTOp::Binary(x, y, ASTBOp::Div))?,
                Node::Pow(x, y) => self.binary_kernel(*x, *y, |x, y| ASTOp::Binary(x, y, ASTBOp::Pow))?,
                Node::Cmplt(x, y) => self.binary_kernel(*x, *y, |x, y| ASTOp::Binary(x, y, ASTBOp::Cmplt))?,
                Node::Where(..) => {
                    // TODO fix this for x == y == z or any combination of those
                    todo!()
                }
                // Movement ops rewrite the argument views in place; after a
                // reduce the views describe the pre-reduce domain, so the
                // parent must be evaluated first (except for permute, which
                // can also permute the reduce axes).
                Node::Reshape(x, sh) => {
                    let mut buffer = if self.kernels[&x].reduce_axes.is_some() {
                        // TODO this should not always evaluate, because dot reshapes
                        // result and then we should still be able to merge it with unary
                        // and binary ops.
                        // TODO reshape can be applied, but no transpose afterwards? Perhaps?
                        // TODO it is perhaps easier to just reorder reshape to come later
                        self.evaluate_kernel(*x)?.clone()
                    } else {
                        self.kernels[&x].clone()
                    };
                    for view in &mut buffer.arg_views {
                        *view = view.reshape(sh);
                    }
                    buffer.shape = sh.clone();
                    buffer
                }
                Node::Expand(x, sh) => {
                    let mut kernel = if self.kernels[&x].reduce_axes.is_some() {
                        self.evaluate_kernel(*x)?.clone()
                    } else {
                        self.kernels[&x].clone()
                    };
                    for view in &mut kernel.arg_views {
                        *view = view.expand(sh);
                    }
                    kernel.shape = sh.clone();
                    kernel
                }
                Node::Permute(x, ax, sh) => {
                    let mut kernel = self.kernels[&x].clone();
                    for view in &mut kernel.arg_views {
                        *view = view.permute(ax);
                    }
                    // Permute commutes with reduce: just permute the axes too.
                    if let Some(reduce_axes) = &mut kernel.reduce_axes {
                        *reduce_axes = reduce_axes.permute(ax);
                    }
                    kernel.shape = sh.clone();
                    kernel
                }
                Node::Pad(x, padding, sh) => {
                    let mut kernel = if self.kernels[&x].reduce_axes.is_some() {
                        self.evaluate_kernel(*x)?.clone()
                    } else {
                        self.kernels[&x].clone()
                    };
                    for view in &mut kernel.arg_views {
                        *view = view.pad(padding);
                    }
                    kernel.shape = sh.clone();
                    kernel
                }
                // Reduce ops: if the parent already reduces, evaluate it first
                // (one reduce per kernel); the evaluated parent collapses to a
                // leaf, so the new Reduce references op index 0.
                Node::Sum(x, ax, _) => {
                    let mut kernel = self.kernels[&x].clone();
                    if kernel.reduce_axes.is_some() {
                        kernel = self.evaluate_kernel(*x)?.clone();
                        kernel.reduce_axes = Some(ax.clone());
                        kernel.reduce_dtype = Some(get_dtype(nodes, nid));
                        kernel.ops.push(ASTOp::Reduce(0, ASTROp::Sum));
                    } else {
                        kernel.reduce_axes = Some(ax.clone());
                        kernel.reduce_dtype = Some(get_dtype(nodes, nid));
                        kernel.ops.push(ASTOp::Reduce(kernel.ops.len() as u8 - 1, ASTROp::Sum));
                    }
                    kernel
                }
                Node::Max(x, ax, _) => {
                    let mut kernel = self.kernels[&x].clone();
                    if kernel.reduce_axes.is_some() {
                        kernel = self.evaluate_kernel(*x)?.clone();
                        kernel.reduce_axes = Some(ax.clone());
                        kernel.reduce_dtype = Some(get_dtype(nodes, nid));
                        kernel.ops.push(ASTOp::Reduce(0, ASTROp::Max));
                    } else {
                        kernel.reduce_axes = Some(ax.clone());
                        kernel.reduce_dtype = Some(get_dtype(nodes, nid));
                        kernel.ops.push(ASTOp::Reduce(kernel.ops.len() as u8 - 1, ASTROp::Max));
                    }
                    kernel
                }
            };
            kernel.flop += nodes[nid.i()].flop(&nodes);
            //std::println!("Inserting kernel {nid}");
            self.kernels.insert(nid, kernel);

            // Heuristic: force evaluation when the fused kernel grows too
            // large, or when this node is shared (rc > 1) and already reads
            // several buffers — evaluating caches it as a cheap leaf.
            // NOTE(review): `rcs[&nid]` panics if nid has no rc entry —
            // assumes every node in `order` appears in `rcs`; confirm.
            if self.kernels[&nid].ops.len() > 200
                || (rcs[&nid] > 1 && self.kernels[&nid].program_args.len() > 1)
            {
                //std::println!("Forcing evaluation of {nid}");
                self.evaluate_kernel(nid)?;
            }
            //std::println!("Kernel {:?}, len of buffers: {}", self.kernels[&nid], self.buffers.len());

            // Decrement parent refcounts; remove parents that are no longer
            // needed so their buffers can be dropped eagerly.
            for p in nodes[nid.i()].parameters() {
                if let Entry::Occupied(e) = rcs.entry(p).and_modify(|rc| *rc -= 1) {
                    if *e.get() == 0 {
                        self.remove(p)?;
                    }
                }
            }
        }
        Ok(())
    }
}
277
impl<C: Compiler> CompiledBackend<C> {
    /// Initialize new compiled backend using provided compiler
    pub fn new(compiler: C) -> Self {
        Self {
            compiler,
            kernels: BTreeMap::new(),
            buffers: BTreeMap::new(),
            programs: BTreeMap::new(),
        }
    }

    /// Compiles (or fetches a cached program for) the kernel stored under `x`,
    /// launches it, stores the resulting buffer, and replaces the kernel with
    /// a leaf over that buffer. Buffers of args no longer referenced by any
    /// kernel are dropped. No-op if `x` already has a buffer.
    fn evaluate_kernel(&mut self, x: Id) -> Result<&Kernel, ZyxError> {
        //kstd::println!("Evaluating kernel {x}");
        if self.buffers.contains_key(&x) {
            //std::println!("Accessing kernel {x}, {:?} {:?}", self.buffers.keys(), self.kernels);
            return Ok(&self.kernels[&x]);
        }
        let Kernel {
            program_args,
            arg_views,
            arg_dtypes,
            ops,
            reduce_axes,
            reduce_dtype,
            shape,
            dtype,
            flop,
            bytes,
        } = self.kernels[&x].clone();
        // The result shape after reduction (kernel.shape is pre-reduce).
        let r_shape = if let Some(reduce_axes) = &reduce_axes {
            shape.clone().reduce(reduce_axes)
        } else {
            shape.clone()
        };
        let ast = AST {
            arg_views,
            arg_dtypes,
            ops,
            shape,
            dtype,
            reduce_axes,
            reduce_dtype,
        };
        //std::println!("Ops");
        //for op in &*ast.ops { std::println!("{op:?}"); }
        // Used cached program or compile new program
        let program = if let Some(program) = self.programs.get(&ast) {
            program
        } else {
            // TODO optimize ast as much as possible here
            // for example deduplicate args
            // NOTE(review): 256 / 256*1024*8 / 64 look like device limits
            // (work size, memory, registers?) — confirm against ast_to_ir.
            let ir = crate::ir::ast_to_ir(&ast, 256, 256*1024*8, 64);
            let program = self.compiler.compile(&ir)?;
            self.programs.entry(ast).or_insert(program)
        };
        let program_args: Vec<&C::Buffer> = program_args
            .into_iter()
            .map(|nid| &self.buffers[&nid])
            .collect();
        // Run the program
        self.buffers.insert(
            x,
            self.compiler.launch(program, &program_args, flop, bytes)?,
        );

        // We need to remove unused kernel and possibly drop its args!
        // The old kernel is replaced by a leaf over the freshly computed
        // buffer; any arg buffer no other kernel references is dropped.
        if let Some(kernel) = self.kernels.insert(x, Kernel::leaf(x, &r_shape, &dtype)) {
            for p in kernel.program_args {
                if !self
                    .kernels
                    .values()
                    .any(|k| k.program_args.contains(&p))
                {
                    if let Some(mut buffer) = self.buffers.remove(&p) {
                        //std::println!("Dropping buffer {p} out of total {} buffers", self.buffers.len());
                        self.compiler.drop_buffer(&mut buffer)?;
                    }
                }
            }
        }
        Ok(&self.kernels[&x])
    }

    /// Merges the kernels of `x` and `y` into one kernel ending in `op`.
    /// `op` receives the AST indices of x's and y's last ops. y's ops are
    /// appended after x's, so their internal indices are rebased: leaf
    /// argument indices by x's arg count, op indices by x's op count.
    /// At most one reduce survives the merge; if both sides reduce, `y` is
    /// evaluated first.
    fn binary_kernel(
        &mut self,
        x: Id,
        y: Id,
        op: impl Fn(u8, u8) -> ASTOp,
    ) -> Result<Kernel, ZyxError> {
        let (reduce_axes, reduce_dtype) = if x != y {
            match (
                self.kernels[&x].reduce_axes.clone(),
                self.kernels[&y].reduce_axes.clone(),
            ) {
                (Some(x_ax), Some(_)) => {
                    // Both reduce: collapse y to a leaf, keep x's reduce.
                    self.evaluate_kernel(y)?;
                    (Some(x_ax), Some(self.kernels[&x].dtype))
                }
                (Some(x_ax), None) => (Some(x_ax), Some(self.kernels[&x].dtype)),
                (None, Some(y_ax)) => (Some(y_ax), Some(self.kernels[&y].dtype)),
                (None, None) => (None, None),
            }
        } else {
            // x == y: apply op to the same operand twice within one kernel.
            //(self.kernels[&x].reduce_axes.clone(), self.kernels[&x].reduce_dtype)
            // TODO if it is one reduce, it could be merged, but it is lot of work
            let mut buffer = if self.kernels[&x].reduce_axes.is_some() {
                self.evaluate_kernel(x)?.clone()
            } else {
                self.kernels[&x].clone()
            };
            let n = buffer.ops.len() as u8 - 1;
            buffer.ops.push(op(n, n));
            //(None, None)
            return Ok(buffer);
        };
        let x_buffer = &self.kernels[&x];
        let y_buffer = &self.kernels[&y];
        // Offset applied to y's op indices once appended after x's ops.
        let n = x_buffer.ops.len() as u8;
        Ok(Kernel {
            program_args: x_buffer
                .program_args
                .iter()
                .chain(y_buffer.program_args.iter())
                .copied()
                .collect(),
            arg_views: x_buffer
                .arg_views
                .iter()
                .chain(y_buffer.arg_views.iter())
                .cloned()
                .collect(),
            arg_dtypes: x_buffer
                .arg_dtypes
                .iter()
                .chain(y_buffer.arg_dtypes.iter())
                .copied()
                .collect(),
            ops: x_buffer
                .ops
                .iter()
                .cloned()
                .chain(y_buffer.ops.iter().cloned().map(|mut op| {
                    match &mut op {
                        // Leaf indices address arg_views, so rebase by x's arg count.
                        ASTOp::Leaf(x) => *x += x_buffer.arg_views.len() as u8,
                        // All other indices address ops, so rebase by n.
                        ASTOp::Unary(x, ..) | ASTOp::Reduce(x, ..) => *x += n,
                        ASTOp::Binary(x, y, ..) => {
                            *x += n;
                            *y += n;
                        }
                        ASTOp::Where(x, y, z) => {
                            *x += n;
                            *y += n;
                            *z += n;
                        }
                    }
                    op
                }))
                // Final op combines x's last op (n - 1) with y's last op.
                .chain([op(n - 1, n + y_buffer.ops.len() as u8 - 1)])
                .collect(),
            reduce_axes,
            reduce_dtype,
            // NOTE(review): result takes x's shape/dtype; assumes callers
            // already broadcast x and y to matching shapes — confirm.
            shape: x_buffer.shape.clone(),
            dtype: x_buffer.dtype,
            flop: x_buffer.flop + y_buffer.flop,
            bytes: x_buffer.bytes + y_buffer.bytes,
        })
    }
}