zyx_compiler/ir/mod.rs

mod elementwise;
mod reduce;
mod work_size;

use crate::ir::work_size::calculate_work_sizes;
use crate::{ASTBOp, ASTOp, ASTUOp, AST};
use alloc::{collections::BTreeMap, string::String, vec::Vec};
use core::fmt::{Display, Formatter};
use zyx_core::{dtype::DType, view::Index};

/// Variable in IR
pub enum Var {
    /// Value in local (shared) memory, printed as `lmem{id}[{index}]`
    Local { id: u8, index: String },
    /// Value in a register, printed as `rmem{id}` or `rmem{id}[{index}]`
    Register { id: u8, index: Option<String> },
    /// f32 constant
    ConstF32(f32),
    /// f64 constant
    ConstF64(f64),
    /// i32 constant
    ConstI32(i32),
    /// i64 constant
    ConstI64(i64),
}

impl Display for Var {
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        match self {
            Var::Local { id, index } => write!(f, "lmem{id}[{index}]"),
            Var::Register { id, index } => {
                if let Some(index) = index {
                    write!(f, "rmem{id}[{index}]")
                } else {
                    write!(f, "rmem{id}")
                }
            }
            Var::ConstF32(value) => write!(f, "{value:.8}f"),
            Var::ConstF64(value) => write!(f, "{value:.16}"),
            Var::ConstI32(value) => write!(f, "{value}"),
            Var::ConstI64(value) => write!(f, "{value}"),
        }
    }
}
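
// A quick sketch of how `Var` renders through the Display impl above
// (the outputs follow directly from the format strings):
//
//     Var::Register { id: 0, index: None }                      -> rmem0
//     Var::Register { id: 1, index: Some(String::from("ri0")) } -> rmem1[ri0]
//     Var::Local { id: 0, index: String::from("lid0") }         -> lmem0[lid0]
//     Var::ConstF32(1.0)                                        -> 1.00000000f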

/// Unary op
pub enum UOp {
    Noop, // Just assign
    Cast(DType),
    Neg,
    Sin,
    Cos,
    Exp,
    Ln,
    Tanh,
    Sqrt,
}

/// Binary op
pub enum BOp {
    Add,
    Sub,
    Mul,
    Div,
    Pow,
    Cmplt,
    Max,
}

/// Op for the compilers
pub enum Op {
    /// Load into res from arg at index
    LoadGlobal { res: Var, arg: u8, index: Index },
    /// Store arg into res at index
    StoreGlobal { res: u8, index: Index, arg: Var },
    /// Declare register variable id with dtype and optionally
    /// with given length (as a vector)
    DeclareVar {
        dtype: DType,
        id: u8,
        len: Option<u8>,
    },
    /// Declare local memory variable with id, dtype and given length
    DeclareLocalVar { id: u8, dtype: DType, len: usize },
    /// Initialize index id with value
    InitIndex { id: u8, value: String },
    /// Declare index
    DeclareIndex { id: u8 },
    /// Set index to value
    SetIndex { id: u8, value: String },
    /// Initialize accumulator. If is_sum_reduce is true,
    /// initialize to zero, otherwise initialize to the minimum
    /// of dtype.
    InitAccumulator {
        id: u8,
        dtype: DType,
        is_sum_reduce: bool,
        len: Option<u8>,
    },
    /// Unary op
    Unary { res: Var, x: Var, op: UOp },
    /// Binary op
    Binary { res: Var, x: Var, y: Var, op: BOp },
    /// Where: if condition x holds, the result is y, otherwise z
    Where { res: Var, x: Var, y: Var, z: Var },
    /// Loop in kernel (register/private)
    Loop {
        name: String,
        upper_bound: usize,
        step: usize,
    },
    /// If condition
    IfBlock { condition: String },
    /// End of if condition
    EndIf,
    /// End of loop
    EndLoop,
    /// Local memory synchronization
    LocalBarrier,
}
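
// A hedged sketch of how a tiny kernel body can be encoded with these ops:
// the sequence below computes `rmem1 = max(rmem0, 0.0f)`, i.e. a scalar ReLU.
// This only illustrates the encoding, not the exact sequence the compiler
// emits for any particular AST.
//
//     let body = vec![
//         Op::DeclareVar { dtype: DType::F32, id: 1, len: None },
//         Op::Binary {
//             res: Var::Register { id: 1, index: None },
//             x: Var::Register { id: 0, index: None },
//             y: Var::ConstF32(0.0),
//             op: BOp::Max,
//         },
//     ];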

/// Intermediate representation for compilers
pub struct IR {
    /// Global work size (total number of threads per dimension)
    pub global_work_size: Vec<usize>,
    /// Local work size (threads per work group per dimension)
    pub local_work_size: Vec<usize>,
    /// Kernel arguments: dtype and whether the buffer is read only
    pub kernel_args: Vec<(DType, bool)>,
    /// Ops in the kernel body
    pub ops: Vec<Op>,
    /// Byte size of the resulting buffer
    pub res_byte_size: usize,
}

// The single most important function and one of the most difficult to write.
// Its output is cached, so take the time to optimize these kernels.
pub(super) fn ast_to_ir(
    ast: &AST,
    max_local_work_size: usize,
    max_local_memory_size: usize,
    max_num_registers: usize,
) -> IR {
    // Byte size of the resulting buffer
    let res_byte_size = if let Some(reduce_axes) = &ast.reduce_axes {
        ast.shape.clone().reduce(reduce_axes).numel() * ast.dtype.byte_size()
    } else {
        ast.shape.numel() * ast.dtype.byte_size()
    };
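    // Worked example (no reduce): an f32 result of shape [2, 3, 4] has
    // numel() == 24 and byte_size() == 4, so res_byte_size == 96.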
    // TODO we should be able to partially separate optimizations from compilation
    // TODO Local memory tiling
    // TODO Register memory tiling
    // TODO Repeat these functions with different parameters (autotuning)
    let (
        arg_views,
        res_shape,
        reduce_dim,
        mut global_work_size,
        mut local_work_size,
        register_work_size,
        tiling_axes,
    ) = calculate_work_sizes(
        &ast.reduce_axes,
        &ast.shape,
        ast.arg_views.clone(),
        max_local_work_size,
        max_num_registers,
    );
    let mut kernel_args = Vec::new();
    for dtype in &ast.arg_dtypes {
        kernel_args.push((*dtype, true));
    }
    // Push result buffer as arg
    kernel_args.push((ast.dtype, false));

    // Compile ops
    let ops = if let Some(reduce_dim) = reduce_dim {
        // Whether it is a 1d, 2d or 3d kernel, it can always have expanded
        // buffers, batches (optionally spread across multiple GPUs)
        // and multi-step reduces.

        // TODO the tiled kernel currently does not work unless the local work
        // sizes in dimensions 1 and 2 both match the one in the reduce dimension
        if tiling_axes.is_empty()
            || local_work_size[1] != local_work_size[3]
            || local_work_size[2] != local_work_size[3]
        {
            if global_work_size.iter().product::<usize>() == reduce_dim {
                // Full reduce, everything collapses into a single value,
                // so apply a two step reduce: find the largest power of two
                // that divides the reduce dimension, capped by
                // max_local_work_size, and let that many threads cooperate.
                let mut d = 1;
                while global_work_size[3] % (d * 2) == 0 && d < max_local_work_size {
                    d *= 2;
                }
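                // Worked example: with global_work_size[3] == 96 and
                // max_local_work_size == 256, d grows 1 -> 2 -> 4 -> 8 -> 16 -> 32
                // and stops at 32 because 96 % 64 != 0, so 32 threads
                // cooperate on each reduced value.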
                global_work_size[2] = d;
                local_work_size[2] = d;
                reduce::two_step_reduce::compile_reduce_kernel(
                    &ast.ops,
                    arg_views,
                    ast.arg_dtypes.clone(),
                    ast.reduce_dtype.unwrap(),
                    reduce_dim,
                    &local_work_size,
                    res_shape,
                )
            } else {
                reduce::compile_reduce_kernel(
                    &ast.ops,
                    arg_views,
                    ast.arg_dtypes.clone(),
                    ast.reduce_dtype.unwrap(),
                    reduce_dim,
                    &local_work_size,
                    res_shape,
                )
            }
        } else {
            reduce::tiled_reduce::compile_reduce_kernel(
                &ast.ops,
                arg_views,
                ast.arg_dtypes.clone(),
                ast.reduce_dtype.unwrap(),
                reduce_dim,
                &global_work_size,
                &local_work_size,
                &register_work_size,
                res_shape,
                tiling_axes,
                max_local_memory_size,
            )
        }
    } else {
        // TODO make sure that local_work_size is always bigger than register_work_size
        elementwise::compile_elementwise_kernel(ast, &local_work_size, arg_views, res_shape)
    };

    // For reduce kernels, the last dimension is the reduce dimension
    // and is dropped from the launch work sizes.
    IR {
        global_work_size: if reduce_dim.is_some() {
            global_work_size[..global_work_size.len() - 1].to_vec()
        } else {
            global_work_size
        },
        local_work_size: if reduce_dim.is_some() {
            local_work_size[..local_work_size.len() - 1].to_vec()
        } else {
            local_work_size
        },
        kernel_args,
        ops,
        res_byte_size,
    }
}
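
// A hedged usage sketch of `ast_to_ir` (the device limits below are made-up
// illustrative values, not queried from any real backend):
//
//     let ir = ast_to_ir(
//         &ast,
//         256,       // max_local_work_size
//         48 * 1024, // max_local_memory_size
//         128,       // max_num_registers
//     );
//     // One kernel arg per input buffer plus one for the result buffer.
//     assert_eq!(ir.kernel_args.len(), ast.arg_dtypes.len() + 1);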

// The same op can be applied multiple times with a different register_index
fn apply_elementwise_op(
    res_id: u8,
    res_dtype: &mut DType,
    ast_op: &ASTOp,
    register_indices: &BTreeMap<u8, String>,
) -> Vec<Op> {
    let mut ops = Vec::new();
    // TODO put all unary ops into a single function or probably a macro
    match ast_op {
        ASTOp::Unary(x, op) => {
            let mut relu = false;
            let op = match op {
                ASTUOp::Cast(dtype) => {
                    *res_dtype = *dtype;
                    UOp::Cast(*dtype)
                }
                ASTUOp::Neg => UOp::Neg,
                ASTUOp::ReLU => {
                    relu = true;
                    // Placeholder that is never emitted: ReLU is lowered
                    // to a binary max with zero below
                    UOp::Neg
                }
                ASTUOp::Sin => UOp::Sin,
                ASTUOp::Cos => UOp::Cos,
                ASTUOp::Exp => UOp::Exp,
                ASTUOp::Ln => UOp::Ln,
                ASTUOp::Tanh => UOp::Tanh,
                ASTUOp::Sqrt => UOp::Sqrt,
            };
            ops.push(Op::DeclareVar {
                dtype: *res_dtype,
                id: res_id,
                len: None,
            });
            if relu {
                ops.push(Op::Binary {
                    res: Var::Register {
                        id: res_id,
                        index: None,
                    },
                    x: Var::Register {
                        id: *x,
                        index: register_indices.get(x).cloned(),
                    },
                    y: match res_dtype {
                        DType::F32 => Var::ConstF32(0.0),
                        DType::F64 => Var::ConstF64(0.0),
                        DType::I32 => Var::ConstI32(0),
                    },
                    op: BOp::Max,
                });
            } else {
                ops.push(Op::Unary {
                    res: Var::Register {
                        id: res_id,
                        index: None,
                    },
                    x: Var::Register {
                        id: *x,
                        index: register_indices.get(x).cloned(),
                    },
                    op,
                });
            }
        }
        ASTOp::Binary(x, y, op) => {
            ops.push(Op::DeclareVar {
                dtype: *res_dtype,
                id: res_id,
                len: None,
            });
            ops.push(Op::Binary {
                res: Var::Register {
                    id: res_id,
                    index: None,
                },
                x: Var::Register {
                    id: *x,
                    index: register_indices.get(x).cloned(),
                },
                y: Var::Register {
                    id: *y,
                    index: register_indices.get(y).cloned(),
                },
                op: match op {
                    ASTBOp::Add => BOp::Add,
                    ASTBOp::Sub => BOp::Sub,
                    ASTBOp::Mul => BOp::Mul,
                    ASTBOp::Div => BOp::Div,
                    ASTBOp::Pow => BOp::Pow,
                    ASTBOp::Cmplt => BOp::Cmplt,
                },
            });
        }
        ASTOp::Where(x, y, z) => {
            ops.push(Op::DeclareVar {
                dtype: *res_dtype,
                id: res_id,
                len: None,
            });
            ops.push(Op::Where {
                res: Var::Register {
                    id: res_id,
                    index: None,
                },
                x: Var::Register {
                    id: *x,
                    index: register_indices.get(x).cloned(),
                },
                y: Var::Register {
                    id: *y,
                    index: register_indices.get(y).cloned(),
                },
                z: Var::Register {
                    id: *z,
                    index: register_indices.get(z).cloned(),
                },
            });
        }
        ASTOp::Leaf(..) | ASTOp::Reduce(..) => {
            panic!("Leaf and Reduce ops are not elementwise and are compiled elsewhere")
        }
    }
    ops
}
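
// A hedged sketch of a call site (illustrative only; the real callers are
// presumably in the elementwise and reduce submodules): lowering
// `rmem2 = rmem0 + rmem1`, where rmem0 is a vectorized register indexed by
// a loop variable named "ri0":
//
//     let mut dtype = DType::F32;
//     let mut register_indices = BTreeMap::new();
//     register_indices.insert(0u8, String::from("ri0"));
//     let ops = apply_elementwise_op(
//         2,
//         &mut dtype,
//         &ASTOp::Binary(0, 1, ASTBOp::Add),
//         &register_indices,
//     );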