1mod elementwise;
2mod reduce;
3mod work_size;
4
5use crate::ir::work_size::calculate_work_sizes;
6use crate::{ASTBOp, ASTOp, ASTUOp, AST};
7use alloc::{string::String, vec::Vec, collections::BTreeMap};
8use core::fmt::{Display, Formatter};
9use zyx_core::{
10 dtype::DType,
11 view::Index,
12};
13
14pub enum Var {
16 Local { id: u8, index: String },
17 Register { id: u8, index: Option<String> },
18 ConstF32(f32),
19 ConstF64(f64),
20 ConstI32(i32),
21 ConstI64(i64),
22}
23
24impl Display for Var {
25 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
26 match self {
27 Var::Local { id, index } => f.write_fmt(format_args!("lmem{id}[{index}]")),
28 Var::Register { id, index } => {
29 if let Some(index) = index {
30 f.write_fmt(format_args!("rmem{id}[{index}]"))
31 } else {
32 f.write_fmt(format_args!("rmem{id}"))
33 }
34 }
35 Var::ConstF32(value) => f.write_fmt(format_args!("{value:.8}f")),
36 Var::ConstF64(value) => f.write_fmt(format_args!("{value:.16}")),
37 Var::ConstI32(value) => f.write_fmt(format_args!("{value}")),
38 Var::ConstI64(value) => f.write_fmt(format_args!("{value}")),
39 }
40 }
41}
42
43pub enum UOp {
45 Noop, Cast(DType),
47 Neg,
48 Sin,
49 Cos,
50 Exp,
51 Ln,
52 Tanh,
53 Sqrt,
54}
55
/// Binary operation carried by [`Op::Binary`].
pub enum BOp {
    /// Addition.
    Add,
    /// Subtraction.
    Sub,
    /// Multiplication.
    Mul,
    /// Division.
    Div,
    /// Exponentiation.
    Pow,
    /// Comparison; presumably "less than", going by the name.
    Cmplt,
    /// Maximum of the two operands. ReLU is lowered to this op with a zero constant.
    Max,
}
66
67pub enum Op {
69 LoadGlobal { res: Var, arg: u8, index: Index },
71 StoreGlobal { res: u8, index: Index, arg: Var },
73 DeclareVar {
76 dtype: DType,
77 id: u8,
78 len: Option<u8>,
79 },
80 DeclareLocalVar { id: u8, dtype: DType, len: usize },
82 InitIndex { id: u8, value: String },
84 DeclareIndex { id: u8 },
86 SetIndex { id: u8, value: String },
88 InitAccumulator {
92 id: u8,
93 dtype: DType,
94 is_sum_reduce: bool,
95 len: Option<u8>,
96 },
97 Unary { res: Var, x: Var, op: UOp },
99 Binary { res: Var, x: Var, y: Var, op: BOp },
101 Where { res: Var, x: Var, y: Var, z: Var },
103 Loop {
105 name: String,
106 upper_bound: usize,
107 step: usize,
108 },
109 IfBlock {
111 condition: String,
112 },
113 EndIf,
115 EndLoop,
117 LocalBarrier,
119}
120
121pub struct IR {
123 pub global_work_size: Vec<usize>,
124 pub local_work_size: Vec<usize>,
125 pub kernel_args: Vec<(DType, bool)>, pub ops: Vec<Op>,
127 pub res_byte_size: usize,
128}
129
130pub(super) fn ast_to_ir(ast: &AST, max_local_work_size: usize, max_local_memory_size: usize, max_num_registers: usize) -> IR {
134 let res_byte_size = if let Some(reduce_axes) = &ast.reduce_axes {
136 ast.shape.clone().reduce(reduce_axes).numel() * ast.dtype.byte_size()
137 } else {
138 ast.shape.numel() * ast.dtype.byte_size()
139 };
140 let (
145 arg_views,
146 res_shape,
147 reduce_dim,
148 mut global_work_size,
149 mut local_work_size,
150 register_work_size,
151 tiling_axes,
152 ) = calculate_work_sizes(
153 &ast.reduce_axes,
154 &ast.shape,
155 ast.arg_views.clone(),
156 max_local_work_size,
157 max_num_registers,
158 );
159 let mut kernel_args = Vec::new();
160 for dtype in &ast.arg_dtypes {
161 kernel_args.push((*dtype, true));
162 }
163 kernel_args.push((ast.dtype, false));
165
166 let ops = if let Some(reduce_dim) = reduce_dim {
168 if tiling_axes.is_empty() || local_work_size[1] != local_work_size[3] || local_work_size[2] != local_work_size[3] {
174 if global_work_size.iter().product::<usize>() == reduce_dim {
175 let mut d = 1;
178 while global_work_size[3] % (d * 2) == 0 && d < max_local_work_size {
179 d *= 2;
180 }
181 global_work_size[2] = d;
182 local_work_size[2] = d;
183 reduce::two_step_reduce::compile_reduce_kernel(
184 &ast.ops,
185 arg_views,
186 ast.arg_dtypes.clone(),
187 ast.reduce_dtype.unwrap(),
188 reduce_dim,
189 &local_work_size,
190 res_shape,
191 )
192 } else {
193 reduce::compile_reduce_kernel(
194 &ast.ops,
195 arg_views,
196 ast.arg_dtypes.clone(),
197 ast.reduce_dtype.unwrap(),
198 reduce_dim,
199 &local_work_size,
200 res_shape,
201 )
202 }
203 } else {
204 reduce::tiled_reduce::compile_reduce_kernel(
205 &ast.ops,
206 arg_views,
207 ast.arg_dtypes.clone(),
208 ast.reduce_dtype.unwrap(),
209 reduce_dim,
210 &global_work_size,
211 &local_work_size,
212 ®ister_work_size,
213 res_shape,
214 tiling_axes,
215 max_local_memory_size,
216 )
217 }
218 } else {
219 elementwise::compile_elementwise_kernel(ast, &local_work_size, arg_views, res_shape)
221 };
222
223 IR {
224 global_work_size: if reduce_dim.is_some() { global_work_size[..global_work_size.len()-1].to_vec() } else { global_work_size },
225 local_work_size: if reduce_dim.is_some() { local_work_size[..local_work_size.len()-1].to_vec() } else { local_work_size },
226 kernel_args,
227 ops,
228 res_byte_size,
229 }
230}
231
232fn apply_elementwise_op(res_id: u8, res_dtype: &mut DType, ast_op: &ASTOp, register_indices: &BTreeMap<u8, String>) -> Vec<Op> {
234 let mut ops = Vec::new();
235 match ast_op {
237 ASTOp::Unary(x, op) => {
238 let mut relu = false;
239 let op = match op {
240 ASTUOp::Cast(dtype) => {
241 *res_dtype = *dtype;
242 UOp::Cast(*dtype)
243 }
244 ASTUOp::Neg => UOp::Neg,
245 ASTUOp::ReLU => {
246 relu = true;
247 UOp::Neg
248 }
249 ASTUOp::Sin => UOp::Sin,
250 ASTUOp::Cos => UOp::Cos,
251 ASTUOp::Exp => UOp::Exp,
252 ASTUOp::Ln => UOp::Ln,
253 ASTUOp::Tanh => UOp::Tanh,
254 ASTUOp::Sqrt => UOp::Sqrt,
255 };
256 ops.push(Op::DeclareVar {
257 dtype: *res_dtype,
258 id: res_id,
259 len: None,
260 });
261 if relu {
262 ops.push(Op::Binary {
263 res: Var::Register {
264 id: res_id,
265 index: None,
266 },
267 x: Var::Register {
268 id: *x,
269 index: register_indices.get(x).cloned(),
270 },
271 y: match res_dtype {
272 DType::F32 => Var::ConstF32(0.0),
273 DType::F64 => Var::ConstF64(0.0),
274 DType::I32 => Var::ConstI32(0),
275 },
276 op: BOp::Max,
277 });
278 } else {
279 ops.push(Op::Unary {
280 res: Var::Register {
281 id: res_id,
282 index: None,
283 },
284 x: Var::Register {
285 id: *x,
286 index: register_indices.get(x).cloned(),
287 },
288 op,
289 });
290 }
291 }
292 ASTOp::Binary(x, y, op) => {
293 ops.push(Op::DeclareVar {
294 dtype: *res_dtype,
295 id: res_id,
296 len: None,
297 });
298 ops.push(Op::Binary {
299 res: Var::Register {
300 id: res_id,
301 index: None,
302 },
303 x: Var::Register {
304 id: *x,
305 index: register_indices.get(x).cloned(),
306 },
307 y: Var::Register {
308 id: *y,
309 index: register_indices.get(y).cloned(),
310 },
311 op: match op {
312 ASTBOp::Add => BOp::Add,
313 ASTBOp::Sub => BOp::Sub,
314 ASTBOp::Mul => BOp::Mul,
315 ASTBOp::Div => BOp::Div,
316 ASTBOp::Pow => BOp::Pow,
317 ASTBOp::Cmplt => BOp::Cmplt,
318 },
319 });
320 }
321 ASTOp::Where(x, y, z) => {
322 ops.push(Op::DeclareVar {
323 dtype: *res_dtype,
324 id: res_id,
325 len: None,
326 });
327 ops.push(Op::Where {
328 res: Var::Register {
329 id: res_id,
330 index: None,
331 },
332 x: Var::Register {
333 id: *x,
334 index: register_indices.get(x).cloned(),
335 },
336 y: Var::Register {
337 id: *y,
338 index: register_indices.get(y).cloned(),
339 },
340 z: Var::Register {
341 id: *z,
342 index: register_indices.get(z).cloned(),
343 },
344 });
345 }
346 ASTOp::Leaf(..) | ASTOp::Reduce(..) => {
347 panic!()
348 }
349 }
350 ops
351}