//! ug/schedule.rs — scheduling and compilation of lazy-buffer computations.

1use crate::lang::op::{ArgId, Ast};
2use crate::{Device, Layout, LazyBuffer, Result};
3use std::collections::HashMap;
4
/// Argument list for a scheduled item: each entry pairs a kernel-level
/// argument id with the lazy buffer bound to it.
type Args<D> = Vec<(ArgId, LazyBuffer<D>)>;
6
/// A single fusable computation: an AST to evaluate together with the
/// destination buffer the result should be stored into.
pub struct KernelItem<D: Device> {
    // Expression tree evaluated by the generated kernel.
    ast: Ast,
    // Destination buffer and the arg id used to refer to it in the kernel.
    dst: (ArgId, LazyBuffer<D>),
    // Layout used for the store; `None` means a contiguous layout built from
    // the destination buffer's shape (see `KernelItem::kernel`).
    dst_layout: Option<Layout>,
    // Buffers the kernel accesses, keyed by arg id.
    args: Args<D>,
}
13
14impl<D: Device> KernelItem<D> {
15    pub fn into_ast(self) -> Ast {
16        self.ast
17    }
18
19    pub fn ast(&self) -> &Ast {
20        &self.ast
21    }
22
23    pub fn kernel(&self) -> Result<crate::lang::op::Kernel> {
24        use crate::lang::op;
25        let ast = &self.ast;
26        let dst = &self.dst;
27        let args = self
28            .args
29            .iter()
30            .map(|(id, lb)| op::Arg::new(*id, crate::lang::Type::Ptr(lb.dtype())))
31            .collect::<Vec<_>>();
32        let dst_layout = match &self.dst_layout {
33            None => Layout::from_shape(dst.1.shape()),
34            Some(l) => l.clone(),
35        };
36        let sto = op::store(dst.0, dst_layout, ast.clone())?;
37        let kernel = op::Kernel::new(format!("realize_{:?}", dst.1.id()), args, vec![sto]);
38        Ok(kernel)
39    }
40}
41
/// A single step of a schedule.
pub enum ScheduleItem<D: Device> {
    /// A kernel generated by fusing ops from the lazy-buffer graph.
    Kernel(KernelItem<D>),
    /// A matrix multiplication handled natively by the device backend rather
    /// than by a generated kernel.
    MatMul {
        dst: LazyBuffer<D>,
        lhs: LazyBuffer<D>,
        rhs: LazyBuffer<D>,
        // Problem dimensions, presumably (batch, m, n, k) — TODO confirm.
        bmnk: (usize, usize, usize, usize),
        // When set, the rhs layout is transposed before running the matmul
        // (see `CompiledSchedule::run`).
        transpose: bool,
    },
    /// A user-provided function applied to the argument buffers.
    Custom {
        f: crate::lazy_buffer::CustomF<D::Slice>,
        args: Args<D>,
    },
    /// A pre-lowered SSA kernel to compile and run as-is.
    Ssa {
        ssa: crate::lang::ssa::Kernel,
        args: Args<D>,
    },
}
60
/// An ordered list of operations to run so as to realize some lazy buffers.
pub struct Schedule<D: Device> {
    /// Elements in `items` are topologically sorted so that they can be run in order.
    items: Vec<ScheduleItem<D>>,
    // Buffer associated with each kernel argument id.
    per_arg_id: HashMap<ArgId, LazyBuffer<D>>,
    // Tracing spans created once and re-entered on each compilation.
    span_compile: tracing::Span,
    span_kernel: tracing::Span,
    // Device taken from the first buffer passed to `Schedule::create`.
    device: D,
}
69
impl<D: Device> Schedule<D> {
    /// Return the lazy buffer registered for `arg_id`, or an error if the id
    /// is unknown to this schedule.
    pub fn get_arg_id(&self, arg_id: ArgId) -> Result<&LazyBuffer<D>> {
        match self.per_arg_id.get(&arg_id) {
            Some(b) => Ok(b),
            None => crate::bail!("no arg for id {arg_id:?}"),
        }
    }

    /// Build a schedule realizing all the provided buffers.
    ///
    /// The device is taken from the first buffer; an error is returned when
    /// `buffers` is empty.
    pub fn create(buffers: &[&LazyBuffer<D>]) -> Result<Self> {
        let device = if buffers.is_empty() {
            crate::bail!("no buffers provided")
        } else {
            buffers[0].device().clone()
        };
        // Count how often each node is used so that shared subtrees can be
        // realized in their own kernel instead of being re-evaluated per use.
        let mut cnts = HashMap::new();
        for buffer in buffers.iter() {
            id_cnts(buffer, &mut cnts)?
        }
        let mut context = Context::new(cnts);
        for &buffer in buffers.iter() {
            context.push_schedule_item(buffer)?;
        }
        let span_compile = tracing::span!(tracing::Level::TRACE, "compile");
        let span_kernel = tracing::span!(tracing::Level::TRACE, "kernel");
        Ok(Self {
            items: context.items,
            device,
            per_arg_id: context.per_arg_id,
            span_compile,
            span_kernel,
        })
    }

    /// Convenience wrapper around [`Self::create`] for a single buffer.
    pub fn create_one(buffer: &LazyBuffer<D>) -> Result<Self> {
        Self::create(&[buffer])
    }

    /// The schedule items, topologically sorted so they can be run in order.
    pub fn items(&self) -> &[ScheduleItem<D>] {
        self.items.as_slice()
    }

    /// Compile every item using a fresh compilation cache.
    pub fn compile(&self) -> Result<CompiledSchedule<D>> {
        self.compile_with_cache(&mut Default::default())
    }

    /// Compile every item into a device function, reusing previously compiled
    /// functions from `compilation_cache` when available.
    pub fn compile_with_cache(
        &self,
        compilation_cache: &mut crate::cache::CompilationCache<D>,
    ) -> Result<CompiledSchedule<D>> {
        let _guard = self.span_compile.enter();
        let mut funcs = Vec::with_capacity(self.items().len());
        for item in self.items() {
            let call = match item {
                // MatMul and Custom items are executed directly by the device,
                // no compilation is needed for them.
                ScheduleItem::MatMul { dst, lhs, rhs, bmnk, transpose } => Func::MatMul {
                    dst: dst.clone(),
                    lhs: lhs.clone(),
                    rhs: rhs.clone(),
                    bmnk: *bmnk,
                    transpose: *transpose,
                },
                ScheduleItem::Custom { f, args } => {
                    Func::Custom { f: f.clone(), args: args.to_vec() }
                }
                ScheduleItem::Ssa { ssa, args } => {
                    // TODO: check args vs ssa.args.
                    // Pre-lowered SSA kernels are cached by their instruction list.
                    if let Some(func) = compilation_cache.get_ssa(ssa.instrs()) {
                        Func::Kernel { func, args: args.to_vec() }
                    } else {
                        let func = self.device.compile(ssa, None)?;
                        let func = std::sync::Arc::new(func);
                        compilation_cache.insert_ssa(ssa.instrs().clone(), func.clone());
                        Func::Kernel { func, args: args.to_vec() }
                    }
                }
                ScheduleItem::Kernel(item) => {
                    let kernel = item.kernel()?;
                    // Cache key is a normalized kernel — presumably so kernels
                    // differing only in ids share a compiled function; confirm
                    // against crate::cache.
                    let norm_kernel = crate::cache::NormalizedKernel::new(&kernel)?;
                    if let Some(func) = compilation_cache.get(&norm_kernel) {
                        // TODO: Simplify args handling so as to ensure that this matches the
                        // construction of args using ssa.args() below.
                        let mut args = vec![];
                        for arg in kernel.args.iter() {
                            let arg_id = arg.id();
                            let arg = self.get_arg_id(arg_id)?;
                            args.push((arg_id, arg.clone()))
                        }
                        Func::Kernel { func, args }
                    } else {
                        let _guard = self.span_kernel.enter();
                        let kernel_name =
                            if kernel.ops.is_empty() { None } else { Some(kernel.ops[0].name()) };
                        let kernel = kernel.optimize()?;
                        // Lowering options: when the device supports a launch
                        // grid and the kernel is a single op, map the largest
                        // dims onto the block/thread axes.
                        let opts = if D::use_grid()
                            && kernel.ops.len() == 1
                            && kernel.ops[0].layout.rank() >= 1
                            && kernel.ops[0].layout.num_elements() >= 1
                        {
                            let mut dims = kernel.ops[0]
                                .layout
                                .dims()
                                .iter()
                                .copied()
                                .enumerate()
                                .collect::<Vec<_>>();
                            // Sort the (axis-index, size) pairs by decreasing size.
                            dims.sort_by(|v1, v2| usize::cmp(&v2.1, &v1.1));
                            if dims.len() >= 2 && dims[1].1 > 1 {
                                // TODO: It might be better to use the layout to determine the
                                // thread dim or to look at reduce dims.
                                crate::lower_op::Opts::default()
                                    .with_block_axis(dims[0].0)
                                    .with_thread_block(dims[1].0, 32)
                            } else if dims.len() > 1 && dims[0].1 > 1 {
                                crate::lower_op::Opts::default().with_global_axis(dims[0].0, 32)
                            } else {
                                crate::lower_op::Opts::default()
                            }
                        } else {
                            crate::lower_op::Opts::default()
                        };
                        let ssa = kernel.lower(&opts)?;
                        let mut args = vec![];
                        for arg in ssa.args().iter() {
                            let arg_id = arg.0.id();
                            let arg = self.get_arg_id(arg_id)?;
                            args.push((arg_id, arg.clone()))
                        }
                        let func = self.device.compile(&ssa, kernel_name.as_deref())?;
                        let func = std::sync::Arc::new(func);
                        compilation_cache.insert(norm_kernel, func.clone());
                        Func::Kernel { func, args }
                    }
                }
            };
            funcs.push(call)
        }
        let device = self.device.clone();
        Ok(CompiledSchedule { funcs, device })
    }
}
209
/// A schedule item after compilation, ready to be executed.
pub enum Func<D: Device> {
    /// A compiled device function together with its buffer arguments.
    Kernel {
        func: std::sync::Arc<D::Func>,
        args: Args<D>,
    },
    /// Matrix multiplication delegated to the device backend.
    MatMul {
        dst: LazyBuffer<D>,
        lhs: LazyBuffer<D>,
        rhs: LazyBuffer<D>,
        // Problem dimensions, presumably (batch, m, n, k) — TODO confirm.
        bmnk: (usize, usize, usize, usize),
        transpose: bool,
    },
    /// User-provided callback applied to the argument buffers.
    Custom {
        f: crate::lazy_buffer::CustomF<D::Slice>,
        args: Args<D>,
    },
}
227
/// The result of compiling a [`Schedule`]: device functions to run in order.
pub struct CompiledSchedule<D: Device> {
    // Compiled steps, in the same order as the schedule items.
    funcs: Vec<Func<D>>,
    device: D,
}
232
impl<D: Device> CompiledSchedule<D> {
    /// Execute every compiled function in schedule order.
    ///
    /// Buffers are allocated (uninitialized) right before their first use and
    /// their data is borrowed only for the duration of each step.
    pub fn run(&self) -> Result<()> {
        let span_mm = tracing::span!(tracing::Level::TRACE, "mm");
        let span_k = tracing::span!(tracing::Level::TRACE, "kernel");
        let span_custom = tracing::span!(tracing::Level::TRACE, "custom");
        // TODO: We should avoid re-running kernels that have unchanged inputs, tracking
        // variables/copies is likely enough for this.
        for func in self.funcs.iter() {
            match func {
                Func::Kernel { func, args } => {
                    let _guard = span_k.enter();
                    // Should we do some deadlock detection?
                    // NOTE(review): buffers may be left uninitialized here,
                    // presumably the kernel writes them before reading — confirm.
                    let mut bs = args
                        .iter()
                        .map(|(_id, lb)| {
                            unsafe { lb.maybe_allocate_uninit()? };
                            let b = lb.data().try_borrow_mut()?;
                            Ok(b)
                        })
                        .collect::<Result<Vec<_>>>()?;
                    // Unwrap the inner buffers (allocated just above) for the device call.
                    let mut bs = bs.iter_mut().map(|b| b.as_mut().unwrap()).collect::<Vec<_>>();
                    self.device.run(func, &mut bs)?
                }
                Func::MatMul { dst, lhs, rhs, bmnk, transpose } => {
                    let _guard = span_mm.enter();
                    let lhs_dims = lhs.dims();
                    let lhs_rank = lhs.rank();
                    let rhs_dims = rhs.dims();
                    let rhs_rank = rhs.rank();

                    // Pad the lower-rank operand with leading unit dims so
                    // both layouts end up with the same rank.
                    let lhs_l = if lhs_rank < rhs_rank {
                        let lhs_dims = [&vec![1; rhs_rank - lhs_rank], lhs_dims].concat();
                        crate::Layout::from_shape(lhs_dims)
                    } else {
                        crate::Layout::from_shape(lhs_dims)
                    };
                    let rhs_l = if rhs_rank < lhs_rank {
                        let rhs_dims = [&vec![1; lhs_rank - rhs_rank], rhs_dims].concat();
                        crate::Layout::from_shape(rhs_dims)
                    } else {
                        crate::Layout::from_shape(rhs_dims)
                    };
                    let rhs_l = if *transpose { rhs_l.transpose() } else { rhs_l };
                    // TODO: provide a nicer api on LazyBuffer to get the underlying buffer and
                    // have it created if necessary.
                    unsafe { dst.maybe_allocate_uninit()? };
                    unsafe { lhs.maybe_allocate_uninit()? };
                    unsafe { rhs.maybe_allocate_uninit()? };
                    let mut dst = dst.data().try_borrow_mut()?;
                    let dst = dst.as_mut().unwrap();
                    let lhs = lhs.data().try_borrow()?;
                    let lhs = lhs.as_ref().unwrap();
                    let rhs = rhs.data().try_borrow()?;
                    let rhs = rhs.as_ref().unwrap();
                    self.device.matmul(dst, lhs, rhs, *bmnk, &lhs_l, &rhs_l)?;
                }
                Func::Custom { f, args } => {
                    let _guard = span_custom.enter();
                    let mut bs = args
                        .iter()
                        .map(|(_id, lb)| {
                            unsafe { lb.maybe_allocate_uninit()? };
                            let b = lb.data().try_borrow_mut()?;
                            Ok(b)
                        })
                        .collect::<Result<Vec<_>>>()?;
                    let bs = bs.iter_mut().map(|v| v.as_mut().unwrap()).collect::<Vec<_>>();
                    f(bs)?;
                }
            }
        }
        Ok(())
    }
}
307
/// Mutable state used while converting a graph of lazy buffers into a
/// topologically sorted list of schedule items.
struct Context<D: Device> {
    // Schedule items accumulated in execution order.
    items: Vec<ScheduleItem<D>>,
    // Buffer associated with each kernel argument id.
    per_arg_id: HashMap<ArgId, LazyBuffer<D>>,
    // Memoized ASTs so shared nodes are only walked once.
    ast_cache: HashMap<crate::lazy_buffer::Id, Ast>,
    // Per-node use counts from `id_cnts`; counts above one cause a dedicated
    // kernel to be generated for the shared subtree.
    id_cnts: HashMap<crate::lazy_buffer::Id, usize>,
}
314
impl<D: Device> Context<D> {
    fn new(id_cnts: HashMap<crate::lazy_buffer::Id, usize>) -> Self {
        Self { items: vec![], per_arg_id: HashMap::new(), ast_cache: HashMap::new(), id_cnts }
    }

    /// Return the lazy buffer registered for `arg_id`, or an error if unknown.
    fn get_arg_id(&self, arg_id: ArgId) -> Result<&LazyBuffer<D>> {
        match self.per_arg_id.get(&arg_id) {
            Some(b) => Ok(b),
            None => crate::bail!("no arg for id {arg_id:?}"),
        }
    }

    /// Recursively convert `b` into an op-level AST. Nodes that cannot be
    /// fused (matmul, reshape, custom, ssa, set) are pushed as separate
    /// schedule items and represented as loads in the returned AST.
    fn walk(&mut self, b: &LazyBuffer<D>) -> Result<Ast> {
        use crate::lazy_buffer::Op;

        let id = b.id();
        // Shared nodes are only walked once thanks to this cache.
        if let Some(ast) = self.ast_cache.get(&id) {
            return Ok(ast.clone());
        }

        let dtype = b.dtype();
        let shape = b.shape();
        let ast = if b.realized()? {
            // Already computed data is simply loaded from its buffer.
            let arg_id = ArgId::new();
            self.per_arg_id.insert(arg_id, b.clone());
            crate::lang::op::load(arg_id, Layout::from_shape(shape), dtype)?
        } else {
            match b.op() {
                Op::Unary(op, arg) => {
                    let ast = self.walk(arg)?;
                    crate::lang::op::unary(*op, ast)?
                }
                Op::Binary(op, lhs, rhs) => {
                    let lhs = self.walk(lhs)?;
                    let rhs = self.walk(rhs)?;
                    crate::lang::op::binary(*op, lhs, rhs)?
                }
                Op::MatMul(lhs, rhs, bmnk, transpose) => {
                    // MatMul currently aren't fused with the rest of the graph. Maybe we should
                    // allow for custom ops that would be handled in the same way.
                    let _lhs_id = self.push_schedule_item(lhs)?;
                    let _rhs_id = self.push_schedule_item(rhs)?;
                    let dst_id = ArgId::new();
                    self.per_arg_id.insert(dst_id, b.clone());
                    self.items.push(ScheduleItem::MatMul {
                        dst: b.clone(),
                        lhs: lhs.clone(),
                        rhs: rhs.clone(),
                        bmnk: *bmnk,
                        transpose: *transpose,
                    });
                    // The matmul result is then consumed through a plain load.
                    crate::lang::op::load(dst_id, Layout::from_shape(shape), dtype)?
                }
                Op::Reduce(op, arg, axis) => {
                    let ast = self.walk(arg)?;
                    crate::lang::op::reduce(*op, ast, *axis)?
                }
                Op::Const(cst) => crate::lang::op::cst(*cst)?,
                Op::Value => {
                    let arg_id = ArgId::new();
                    self.per_arg_id.insert(arg_id, b.clone());
                    crate::lang::op::load(arg_id, Layout::from_shape(shape), dtype)?
                }
                Op::Reshape(arg) => {
                    // Reshapes are not fused: the argument is realized by its
                    // own kernel then re-loaded with the reshaped layout.
                    let dst_id = self.push_schedule_item(arg)?;
                    crate::lang::op::load(dst_id, Layout::from_shape(shape), dtype)?
                }
                Op::Layout(op, arg) => {
                    let ast = self.walk(arg)?;
                    crate::lang::op::layout(op.clone(), ast, shape)?
                }
                Op::Ssa { ssa, args: b_args } => {
                    let mut args = Vec::with_capacity(b_args.len() + 1);
                    for arg in b_args.iter() {
                        let arg_id = self.push_schedule_item(arg)?;
                        args.push((arg_id, arg.clone()))
                    }
                    // The output buffer is passed as the last argument.
                    let dst_id = ArgId::new();
                    self.per_arg_id.insert(dst_id, b.clone());
                    args.push((dst_id, b.clone()));
                    self.items.push(ScheduleItem::Ssa { ssa: ssa.clone(), args });
                    crate::lang::op::load(dst_id, Layout::from_shape(shape), dtype)?
                }
                Op::Set { values, src, dst_layout } => {
                    // Realize `src` first, then schedule a kernel writing
                    // `values` into it using `dst_layout`; the final value is
                    // read back from `src`.
                    let arg_id = self.push_schedule_item(src)?;
                    let values = self.walk(values)?;
                    self.push_kernel(src, values, Some(dst_layout.clone()))?;
                    crate::lang::op::load(arg_id, Layout::from_shape(shape), dtype)?
                }
                Op::CustomIp { f, args: b_args, src } => {
                    let mut args = Vec::with_capacity(b_args.len() + 1);
                    for arg in b_args.iter() {
                        let arg_id = self.push_schedule_item(arg)?;
                        args.push((arg_id, arg.clone()))
                    }
                    // In-place custom op: `src` is both the last argument and
                    // the buffer the result is loaded from.
                    let arg_id = self.push_schedule_item(src)?;
                    args.push((arg_id, src.clone()));
                    self.items.push(ScheduleItem::Custom { f: f.clone(), args });
                    crate::lang::op::load(arg_id, Layout::from_shape(shape), dtype)?
                }
                Op::Custom { f, args: b_args } => {
                    let mut args = Vec::with_capacity(b_args.len() + 1);
                    for arg in b_args.iter() {
                        let arg_id = self.push_schedule_item(arg)?;
                        args.push((arg_id, arg.clone()))
                    }
                    // The output buffer is passed as the last argument.
                    let dst_id = ArgId::new();
                    self.per_arg_id.insert(dst_id, b.clone());
                    args.push((dst_id, b.clone()));
                    self.items.push(ScheduleItem::Custom { f: f.clone(), args });
                    crate::lang::op::load(dst_id, Layout::from_shape(shape), dtype)?
                }
            }
        };
        // When a subtree appears multiple times in the ast, generate a dedicated kernel.
        let ast = if self.id_cnts.get(&id).copied().unwrap_or(0) > 1 {
            let dst_id = self.push_kernel(b, ast, None)?;
            crate::lang::op::load(dst_id, Layout::from_shape(shape), dtype)?
        } else {
            ast
        };
        self.ast_cache.insert(id, ast.clone());
        Ok(ast)
    }

    /// Schedule a kernel storing `ast` into `buffer` and return the arg id
    /// bound to the destination.
    fn push_kernel(
        &mut self,
        buffer: &LazyBuffer<D>,
        ast: Ast,
        dst_layout: Option<Layout>,
    ) -> Result<ArgId> {
        if let crate::lang::op::AstInner::Load { src: src_arg_id, .. } = ast.inner.as_ref() {
            let src = self.get_arg_id(*src_arg_id)?;
            if std::ptr::eq(src.data(), buffer.data()) {
                // Avoid the cases where we load and store immediately a buffer, this is a no-op
                // and would result in a deadlock.
                return Ok(*src_arg_id);
            }
        }

        let dst_id = ArgId::new();
        self.per_arg_id.insert(dst_id, buffer.clone());
        // The kernel args are all the buffers loaded by the ast plus the
        // destination buffer.
        let mut arg_ids = ast.arg_ids();
        arg_ids.insert(dst_id);
        let args = arg_ids
            .into_iter()
            .map(|arg_id| {
                let arg = self.get_arg_id(arg_id)?;
                Ok((arg_id, arg.clone()))
            })
            .collect::<Result<Vec<_>>>()?;
        let si = KernelItem { ast, dst: (dst_id, buffer.clone()), args, dst_layout };
        self.items.push(ScheduleItem::Kernel(si));
        Ok(dst_id)
    }

    /// Walk `buffer` into an AST and schedule a kernel realizing it.
    fn push_schedule_item(&mut self, buffer: &LazyBuffer<D>) -> Result<ArgId> {
        let ast = self.walk(buffer)?;
        self.push_kernel(buffer, ast, None)
    }
}
476
477/// Return the number of uses for each buffer that is reachable from b. The number of uses can be
478/// either 1 or 2 for the case where the buffer is used twice or more.
479/// Note that realized nodes stop the propagation.
480fn id_cnts<D: Device>(
481    b: &LazyBuffer<D>,
482    cnts: &mut HashMap<crate::lazy_buffer::Id, usize>,
483) -> Result<()> {
484    use crate::lazy_buffer::Op;
485
486    if b.realized()? {
487        return Ok(());
488    }
489
490    let id = b.id();
491    let cnt = cnts.entry(id).or_insert(0);
492    *cnt += 1;
493    if *cnt > 1 {
494        return Ok(());
495    }
496    match b.op() {
497        Op::Value | Op::Const(_) => {}
498        Op::Reshape(arg) | Op::Layout(_, arg) | Op::Reduce(_, arg, _) | Op::Unary(_, arg) => {
499            id_cnts(arg, cnts)?
500        }
501        Op::Set { src: arg1, values: arg2, dst_layout: _ }
502        | Op::MatMul(arg1, arg2, _, _)
503        | Op::Binary(_, arg1, arg2) => {
504            id_cnts(arg1, cnts)?;
505            id_cnts(arg2, cnts)?;
506        }
507        Op::CustomIp { f: _, args, src } => {
508            for arg in args.iter() {
509                id_cnts(arg, cnts)?;
510            }
511            id_cnts(src, cnts)?
512        }
513        Op::Ssa { ssa: _, args } | Op::Custom { f: _, args } => {
514            for arg in args.iter() {
515                id_cnts(arg, cnts)?
516            }
517        }
518    }
519    Ok(())
520}