1use crate::lang::op::{ArgId, Ast};
2use crate::{Device, Layout, LazyBuffer, Result};
3use std::collections::HashMap;
4
/// Argument list for a scheduled item: each `ArgId` paired with the lazy buffer it binds to.
type Args<D> = Vec<(ArgId, LazyBuffer<D>)>;
6
/// A fused kernel to be realized: the AST to evaluate, the destination
/// buffer the result is stored into, and the buffers the AST reads from.
pub struct KernelItem<D: Device> {
    // Expression tree evaluated by the kernel.
    ast: Ast,
    // Destination argument id and the buffer receiving the result.
    dst: (ArgId, LazyBuffer<D>),
    // Store layout; when `None`, a layout derived from the destination
    // buffer's shape is used (see `kernel()`).
    dst_layout: Option<Layout>,
    // Input/output buffers referenced by the kernel, keyed by arg id.
    args: Args<D>,
}
13
14impl<D: Device> KernelItem<D> {
15 pub fn into_ast(self) -> Ast {
16 self.ast
17 }
18
19 pub fn ast(&self) -> &Ast {
20 &self.ast
21 }
22
23 pub fn kernel(&self) -> Result<crate::lang::op::Kernel> {
24 use crate::lang::op;
25 let ast = &self.ast;
26 let dst = &self.dst;
27 let args = self
28 .args
29 .iter()
30 .map(|(id, lb)| op::Arg::new(*id, crate::lang::Type::Ptr(lb.dtype())))
31 .collect::<Vec<_>>();
32 let dst_layout = match &self.dst_layout {
33 None => Layout::from_shape(dst.1.shape()),
34 Some(l) => l.clone(),
35 };
36 let sto = op::store(dst.0, dst_layout, ast.clone())?;
37 let kernel = op::Kernel::new(format!("realize_{:?}", dst.1.id()), args, vec![sto]);
38 Ok(kernel)
39 }
40}
41
/// A single unit of work in a [`Schedule`].
pub enum ScheduleItem<D: Device> {
    /// A fused kernel described by an AST (see [`KernelItem`]).
    Kernel(KernelItem<D>),
    /// A matrix multiplication dispatched to the device's matmul path.
    MatMul {
        dst: LazyBuffer<D>,
        lhs: LazyBuffer<D>,
        rhs: LazyBuffer<D>,
        // Going by the name, (batch, m, n, k) sizes — TODO confirm
        // against the device matmul contract.
        bmnk: (usize, usize, usize, usize),
        // When set, the rhs layout is transposed before the matmul
        // (see `CompiledSchedule::run`).
        transpose: bool,
    },
    /// A user-provided function run directly on the argument slices.
    Custom {
        f: crate::lazy_buffer::CustomF<D::Slice>,
        args: Args<D>,
    },
    /// A pre-lowered SSA kernel, compiled as-is.
    Ssa {
        ssa: crate::lang::ssa::Kernel,
        args: Args<D>,
    },
}
60
/// An ordered list of schedule items plus the bookkeeping needed to
/// compile them for a single device.
pub struct Schedule<D: Device> {
    // Items are executed in insertion order.
    items: Vec<ScheduleItem<D>>,
    // Maps every ArgId used by the items back to its lazy buffer.
    per_arg_id: HashMap<ArgId, LazyBuffer<D>>,
    // Tracing spans created once and re-entered on each compilation.
    span_compile: tracing::Span,
    span_kernel: tracing::Span,
    device: D,
}
69
impl<D: Device> Schedule<D> {
    /// Look up the lazy buffer bound to `arg_id`, failing if the id was
    /// never registered during schedule creation.
    pub fn get_arg_id(&self, arg_id: ArgId) -> Result<&LazyBuffer<D>> {
        match self.per_arg_id.get(&arg_id) {
            Some(b) => Ok(b),
            None => crate::bail!("no arg for id {arg_id:?}"),
        }
    }

    /// Build a schedule that realizes all of `buffers`.
    ///
    /// Errors when `buffers` is empty. Buffer-id occurrence counts are
    /// gathered first so the graph walk can materialize shared
    /// subexpressions once instead of inlining them at every use.
    pub fn create(buffers: &[&LazyBuffer<D>]) -> Result<Self> {
        let device = if buffers.is_empty() {
            crate::bail!("no buffers provided")
        } else {
            // NOTE(review): the device is taken from the first buffer;
            // it is not verified that the others live on the same
            // device — confirm callers guarantee this.
            buffers[0].device().clone()
        };
        let mut cnts = HashMap::new();
        for buffer in buffers.iter() {
            id_cnts(buffer, &mut cnts)?
        }
        let mut context = Context::new(cnts);
        for &buffer in buffers.iter() {
            context.push_schedule_item(buffer)?;
        }
        let span_compile = tracing::span!(tracing::Level::TRACE, "compile");
        let span_kernel = tracing::span!(tracing::Level::TRACE, "kernel");
        Ok(Self {
            items: context.items,
            device,
            per_arg_id: context.per_arg_id,
            span_compile,
            span_kernel,
        })
    }

    /// Convenience wrapper around [`Self::create`] for a single buffer.
    pub fn create_one(buffer: &LazyBuffer<D>) -> Result<Self> {
        Self::create(&[buffer])
    }

    /// The schedule items, in execution order.
    pub fn items(&self) -> &[ScheduleItem<D>] {
        self.items.as_slice()
    }

    /// Compile every item using a fresh, throwaway compilation cache.
    pub fn compile(&self) -> Result<CompiledSchedule<D>> {
        self.compile_with_cache(&mut Default::default())
    }

    /// Compile every item into a runnable [`CompiledSchedule`], reusing
    /// previously compiled functions from `compilation_cache` when an
    /// identical SSA body or normalized kernel was seen before.
    pub fn compile_with_cache(
        &self,
        compilation_cache: &mut crate::cache::CompilationCache<D>,
    ) -> Result<CompiledSchedule<D>> {
        let _guard = self.span_compile.enter();
        let mut funcs = Vec::with_capacity(self.items().len());
        for item in self.items() {
            let call = match item {
                // MatMul and Custom items need no compilation step:
                // they are forwarded to the device/user function as-is.
                ScheduleItem::MatMul { dst, lhs, rhs, bmnk, transpose } => Func::MatMul {
                    dst: dst.clone(),
                    lhs: lhs.clone(),
                    rhs: rhs.clone(),
                    bmnk: *bmnk,
                    transpose: *transpose,
                },
                ScheduleItem::Custom { f, args } => {
                    Func::Custom { f: f.clone(), args: args.to_vec() }
                }
                // Pre-lowered SSA kernels: cached by their instruction
                // list, compiled on a miss.
                ScheduleItem::Ssa { ssa, args } => {
                    if let Some(func) = compilation_cache.get_ssa(ssa.instrs()) {
                        Func::Kernel { func, args: args.to_vec() }
                    } else {
                        let func = self.device.compile(ssa, None)?;
                        let func = std::sync::Arc::new(func);
                        compilation_cache.insert_ssa(ssa.instrs().clone(), func.clone());
                        Func::Kernel { func, args: args.to_vec() }
                    }
                }
                ScheduleItem::Kernel(item) => {
                    let kernel = item.kernel()?;
                    // Kernels are cached by a normalized form so that
                    // structurally identical kernels hit the cache.
                    let norm_kernel = crate::cache::NormalizedKernel::new(&kernel)?;
                    if let Some(func) = compilation_cache.get(&norm_kernel) {
                        // Cache hit: only rebind the argument buffers
                        // for this particular item.
                        let mut args = vec![];
                        for arg in kernel.args.iter() {
                            let arg_id = arg.id();
                            let arg = self.get_arg_id(arg_id)?;
                            args.push((arg_id, arg.clone()))
                        }
                        Func::Kernel { func, args }
                    } else {
                        let _guard = self.span_kernel.enter();
                        // Remember the pre-optimization op name for the
                        // compiled function's label.
                        let kernel_name =
                            if kernel.ops.is_empty() { None } else { Some(kernel.ops[0].name()) };
                        let kernel = kernel.optimize()?;
                        // Grid heuristic: only for single-op kernels on
                        // grid-capable devices.
                        let opts = if D::use_grid()
                            && kernel.ops.len() == 1
                            && kernel.ops[0].layout.rank() >= 1
                            && kernel.ops[0].layout.num_elements() >= 1
                        {
                            // Sort (axis, size) pairs by size,
                            // descending, to find the largest axes.
                            let mut dims = kernel.ops[0]
                                .layout
                                .dims()
                                .iter()
                                .copied()
                                .enumerate()
                                .collect::<Vec<_>>();
                            dims.sort_by(|v1, v2| usize::cmp(&v2.1, &v1.1));
                            if dims.len() >= 2 && dims[1].1 > 1 {
                                // Two non-trivial axes: largest becomes
                                // the block axis, second largest gets a
                                // thread block of 32.
                                crate::lower_op::Opts::default()
                                    .with_block_axis(dims[0].0)
                                    .with_thread_block(dims[1].0, 32)
                            } else if dims.len() > 1 && dims[0].1 > 1 {
                                // NOTE(review): `dims.len() > 1` means a
                                // rank-1 layout never takes this
                                // global-axis path — confirm intended.
                                crate::lower_op::Opts::default().with_global_axis(dims[0].0, 32)
                            } else {
                                crate::lower_op::Opts::default()
                            }
                        } else {
                            crate::lower_op::Opts::default()
                        };
                        let ssa = kernel.lower(&opts)?;
                        // Bind args from the lowered SSA, whose arg
                        // list may differ from the pre-lowering kernel.
                        let mut args = vec![];
                        for arg in ssa.args().iter() {
                            let arg_id = arg.0.id();
                            let arg = self.get_arg_id(arg_id)?;
                            args.push((arg_id, arg.clone()))
                        }
                        let func = self.device.compile(&ssa, kernel_name.as_deref())?;
                        let func = std::sync::Arc::new(func);
                        compilation_cache.insert(norm_kernel, func.clone());
                        Func::Kernel { func, args }
                    }
                }
            };
            funcs.push(call)
        }
        let device = self.device.clone();
        Ok(CompiledSchedule { funcs, device })
    }
}
209
/// A compiled, ready-to-run unit of a [`CompiledSchedule`].
pub enum Func<D: Device> {
    /// A device-compiled kernel plus the buffers it binds.
    Kernel {
        func: std::sync::Arc<D::Func>,
        args: Args<D>,
    },
    /// A matmul executed through the device's matmul entry point.
    MatMul {
        dst: LazyBuffer<D>,
        lhs: LazyBuffer<D>,
        rhs: LazyBuffer<D>,
        // Going by the name, (batch, m, n, k) sizes — TODO confirm.
        bmnk: (usize, usize, usize, usize),
        transpose: bool,
    },
    /// A user-provided function applied to the argument slices.
    Custom {
        f: crate::lazy_buffer::CustomF<D::Slice>,
        args: Args<D>,
    },
}
227
/// The result of compiling a [`Schedule`]: functions ready to run on
/// `device`, in execution order.
pub struct CompiledSchedule<D: Device> {
    funcs: Vec<Func<D>>,
    device: D,
}
232
impl<D: Device> CompiledSchedule<D> {
    /// Execute every compiled function in order, allocating backing
    /// storage for buffers on first use.
    ///
    /// Errors propagate from allocation, borrowing, and device calls;
    /// execution stops at the first failure.
    pub fn run(&self) -> Result<()> {
        let span_mm = tracing::span!(tracing::Level::TRACE, "mm");
        let span_k = tracing::span!(tracing::Level::TRACE, "kernel");
        let span_custom = tracing::span!(tracing::Level::TRACE, "custom");
        for func in self.funcs.iter() {
            match func {
                Func::Kernel { func, args } => {
                    let _guard = span_k.enter();
                    // Borrow every arg mutably (inputs and outputs
                    // alike) after making sure its storage exists.
                    // NOTE(review): `maybe_allocate_uninit` presumably
                    // allocates uninitialized storage on first use —
                    // confirm its safety contract.
                    let mut bs = args
                        .iter()
                        .map(|(_id, lb)| {
                            unsafe { lb.maybe_allocate_uninit()? };
                            let b = lb.data().try_borrow_mut()?;
                            Ok(b)
                        })
                        .collect::<Result<Vec<_>>>()?;
                    // Unwrap the inner Option of each borrow; storage
                    // was just allocated above.
                    let mut bs = bs.iter_mut().map(|b| b.as_mut().unwrap()).collect::<Vec<_>>();
                    self.device.run(func, &mut bs)?
                }
                Func::MatMul { dst, lhs, rhs, bmnk, transpose } => {
                    let _guard = span_mm.enter();
                    let lhs_dims = lhs.dims();
                    let lhs_rank = lhs.rank();
                    let rhs_dims = rhs.dims();
                    let rhs_rank = rhs.rank();

                    // Pad the lower-rank side with leading 1s so both
                    // layouts have the same rank.
                    let lhs_l = if lhs_rank < rhs_rank {
                        let lhs_dims = [&vec![1; rhs_rank - lhs_rank], lhs_dims].concat();
                        crate::Layout::from_shape(lhs_dims)
                    } else {
                        crate::Layout::from_shape(lhs_dims)
                    };
                    let rhs_l = if rhs_rank < lhs_rank {
                        let rhs_dims = [&vec![1; lhs_rank - rhs_rank], rhs_dims].concat();
                        crate::Layout::from_shape(rhs_dims)
                    } else {
                        crate::Layout::from_shape(rhs_dims)
                    };
                    // Transpose only affects the rhs layout, applied
                    // after rank padding.
                    let rhs_l = if *transpose { rhs_l.transpose() } else { rhs_l };
                    unsafe { dst.maybe_allocate_uninit()? };
                    unsafe { lhs.maybe_allocate_uninit()? };
                    unsafe { rhs.maybe_allocate_uninit()? };
                    // dst is borrowed mutably, lhs/rhs immutably; the
                    // shadowing keeps the guards alive for the call.
                    let mut dst = dst.data().try_borrow_mut()?;
                    let dst = dst.as_mut().unwrap();
                    let lhs = lhs.data().try_borrow()?;
                    let lhs = lhs.as_ref().unwrap();
                    let rhs = rhs.data().try_borrow()?;
                    let rhs = rhs.as_ref().unwrap();
                    self.device.matmul(dst, lhs, rhs, *bmnk, &lhs_l, &rhs_l)?;
                }
                Func::Custom { f, args } => {
                    let _guard = span_custom.enter();
                    // Same borrow pattern as the Kernel arm: allocate
                    // then borrow every arg mutably.
                    let mut bs = args
                        .iter()
                        .map(|(_id, lb)| {
                            unsafe { lb.maybe_allocate_uninit()? };
                            let b = lb.data().try_borrow_mut()?;
                            Ok(b)
                        })
                        .collect::<Result<Vec<_>>>()?;
                    let bs = bs.iter_mut().map(|v| v.as_mut().unwrap()).collect::<Vec<_>>();
                    f(bs)?;
                }
            }
        }
        Ok(())
    }
}
307
/// Mutable state used while converting a lazy-buffer graph into a
/// [`Schedule`].
struct Context<D: Device> {
    // Schedule items accumulated so far, in execution order.
    items: Vec<ScheduleItem<D>>,
    // Maps each registered ArgId back to its lazy buffer.
    per_arg_id: HashMap<ArgId, LazyBuffer<D>>,
    // Memoizes the AST produced per buffer id so shared graph nodes are
    // walked only once.
    ast_cache: HashMap<crate::lazy_buffer::Id, Ast>,
    // Occurrence count per buffer id (computed by `id_cnts`); counts
    // above 1 trigger materialization of the shared subexpression.
    id_cnts: HashMap<crate::lazy_buffer::Id, usize>,
}
314
impl<D: Device> Context<D> {
    /// Start an empty context seeded with precomputed id counts.
    fn new(id_cnts: HashMap<crate::lazy_buffer::Id, usize>) -> Self {
        Self { items: vec![], per_arg_id: HashMap::new(), ast_cache: HashMap::new(), id_cnts }
    }

    /// Look up the lazy buffer registered for `arg_id`.
    fn get_arg_id(&self, arg_id: ArgId) -> Result<&LazyBuffer<D>> {
        match self.per_arg_id.get(&arg_id) {
            Some(b) => Ok(b),
            None => crate::bail!("no arg for id {arg_id:?}"),
        }
    }

    /// Convert buffer `b` into an AST, pushing schedule items for any
    /// work that cannot be fused (matmuls, custom ops, SSA kernels,
    /// reshapes, shared subexpressions).
    ///
    /// Results are memoized per buffer id; already-realized buffers
    /// become plain loads. As a side effect this registers arg-ids in
    /// `per_arg_id` and appends to `items`.
    fn walk(&mut self, b: &LazyBuffer<D>) -> Result<Ast> {
        use crate::lazy_buffer::Op;

        let id = b.id();
        if let Some(ast) = self.ast_cache.get(&id) {
            return Ok(ast.clone());
        }

        let dtype = b.dtype();
        let shape = b.shape();
        let ast = if b.realized()? {
            // Data already exists: load it directly.
            let arg_id = ArgId::new();
            self.per_arg_id.insert(arg_id, b.clone());
            crate::lang::op::load(arg_id, Layout::from_shape(shape), dtype)?
        } else {
            match b.op() {
                // Elementwise/reduce/layout ops fuse into the AST.
                Op::Unary(op, arg) => {
                    let ast = self.walk(arg)?;
                    crate::lang::op::unary(*op, ast)?
                }
                Op::Binary(op, lhs, rhs) => {
                    let lhs = self.walk(lhs)?;
                    let rhs = self.walk(rhs)?;
                    crate::lang::op::binary(*op, lhs, rhs)?
                }
                // Matmul cannot fuse: schedule both operands, emit a
                // MatMul item, and load its destination.
                Op::MatMul(lhs, rhs, bmnk, transpose) => {
                    let _lhs_id = self.push_schedule_item(lhs)?;
                    let _rhs_id = self.push_schedule_item(rhs)?;
                    let dst_id = ArgId::new();
                    self.per_arg_id.insert(dst_id, b.clone());
                    self.items.push(ScheduleItem::MatMul {
                        dst: b.clone(),
                        lhs: lhs.clone(),
                        rhs: rhs.clone(),
                        bmnk: *bmnk,
                        transpose: *transpose,
                    });
                    crate::lang::op::load(dst_id, Layout::from_shape(shape), dtype)?
                }
                Op::Reduce(op, arg, axis) => {
                    let ast = self.walk(arg)?;
                    crate::lang::op::reduce(*op, ast, *axis)?
                }
                Op::Const(cst) => crate::lang::op::cst(*cst)?,
                // Opaque value: load it through a fresh arg id.
                Op::Value => {
                    let arg_id = ArgId::new();
                    self.per_arg_id.insert(arg_id, b.clone());
                    crate::lang::op::load(arg_id, Layout::from_shape(shape), dtype)?
                }
                // Reshape: materialize the source, then reload it under
                // this buffer's shape.
                Op::Reshape(arg) => {
                    let dst_id = self.push_schedule_item(arg)?;
                    crate::lang::op::load(dst_id, Layout::from_shape(shape), dtype)?
                }
                Op::Layout(op, arg) => {
                    let ast = self.walk(arg)?;
                    crate::lang::op::layout(op.clone(), ast, shape)?
                }
                // SSA kernel: schedule inputs, append the destination
                // as the last arg, then load the destination.
                Op::Ssa { ssa, args: b_args } => {
                    let mut args = Vec::with_capacity(b_args.len() + 1);
                    for arg in b_args.iter() {
                        let arg_id = self.push_schedule_item(arg)?;
                        args.push((arg_id, arg.clone()))
                    }
                    let dst_id = ArgId::new();
                    self.per_arg_id.insert(dst_id, b.clone());
                    args.push((dst_id, b.clone()));
                    self.items.push(ScheduleItem::Ssa { ssa: ssa.clone(), args });
                    crate::lang::op::load(dst_id, Layout::from_shape(shape), dtype)?
                }
                // Set: schedule `src` first, then store the walked
                // `values` AST into it using `dst_layout`, and read the
                // result back as a load of `src`'s arg id.
                Op::Set { values, src, dst_layout } => {
                    let arg_id = self.push_schedule_item(src)?;
                    let values = self.walk(values)?;
                    self.push_kernel(src, values, Some(dst_layout.clone()))?;
                    crate::lang::op::load(arg_id, Layout::from_shape(shape), dtype)?
                }
                // In-place custom op: the source buffer is both the
                // last argument and the value loaded afterwards.
                Op::CustomIp { f, args: b_args, src } => {
                    let mut args = Vec::with_capacity(b_args.len() + 1);
                    for arg in b_args.iter() {
                        let arg_id = self.push_schedule_item(arg)?;
                        args.push((arg_id, arg.clone()))
                    }
                    let arg_id = self.push_schedule_item(src)?;
                    args.push((arg_id, src.clone()));
                    self.items.push(ScheduleItem::Custom { f: f.clone(), args });
                    crate::lang::op::load(arg_id, Layout::from_shape(shape), dtype)?
                }
                // Out-of-place custom op: like Ssa, the destination is
                // appended as the last argument.
                Op::Custom { f, args: b_args } => {
                    let mut args = Vec::with_capacity(b_args.len() + 1);
                    for arg in b_args.iter() {
                        let arg_id = self.push_schedule_item(arg)?;
                        args.push((arg_id, arg.clone()))
                    }
                    let dst_id = ArgId::new();
                    self.per_arg_id.insert(dst_id, b.clone());
                    args.push((dst_id, b.clone()));
                    self.items.push(ScheduleItem::Custom { f: f.clone(), args });
                    crate::lang::op::load(dst_id, Layout::from_shape(shape), dtype)?
                }
            }
        };
        // Buffers used more than once are materialized with their own
        // kernel and replaced by a load, instead of inlining their AST
        // at every use site.
        let ast = if self.id_cnts.get(&id).copied().unwrap_or(0) > 1 {
            let dst_id = self.push_kernel(b, ast, None)?;
            crate::lang::op::load(dst_id, Layout::from_shape(shape), dtype)?
        } else {
            ast
        };
        self.ast_cache.insert(id, ast.clone());
        Ok(ast)
    }

    /// Schedule a kernel that materializes `ast` into `buffer` and
    /// return the destination arg id.
    ///
    /// When `ast` is already a plain load from a buffer sharing
    /// `buffer`'s data, no kernel is emitted and the existing arg id is
    /// reused (avoids no-op copy kernels).
    fn push_kernel(
        &mut self,
        buffer: &LazyBuffer<D>,
        ast: Ast,
        dst_layout: Option<Layout>,
    ) -> Result<ArgId> {
        if let crate::lang::op::AstInner::Load { src: src_arg_id, .. } = ast.inner.as_ref() {
            let src = self.get_arg_id(*src_arg_id)?;
            // Pointer equality of the backing data identifies "load of
            // the same buffer".
            if std::ptr::eq(src.data(), buffer.data()) {
                return Ok(*src_arg_id);
            }
        }

        let dst_id = ArgId::new();
        self.per_arg_id.insert(dst_id, buffer.clone());
        // The kernel's args are everything the AST reads plus the
        // destination.
        let mut arg_ids = ast.arg_ids();
        arg_ids.insert(dst_id);
        let args = arg_ids
            .into_iter()
            .map(|arg_id| {
                let arg = self.get_arg_id(arg_id)?;
                Ok((arg_id, arg.clone()))
            })
            .collect::<Result<Vec<_>>>()?;
        let si = KernelItem { ast, dst: (dst_id, buffer.clone()), args, dst_layout };
        self.items.push(ScheduleItem::Kernel(si));
        Ok(dst_id)
    }

    /// Walk `buffer` and schedule a kernel realizing it, returning the
    /// arg id its result can be loaded from.
    fn push_schedule_item(&mut self, buffer: &LazyBuffer<D>) -> Result<ArgId> {
        let ast = self.walk(buffer)?;
        self.push_kernel(buffer, ast, None)
    }
}
476
477fn id_cnts<D: Device>(
481 b: &LazyBuffer<D>,
482 cnts: &mut HashMap<crate::lazy_buffer::Id, usize>,
483) -> Result<()> {
484 use crate::lazy_buffer::Op;
485
486 if b.realized()? {
487 return Ok(());
488 }
489
490 let id = b.id();
491 let cnt = cnts.entry(id).or_insert(0);
492 *cnt += 1;
493 if *cnt > 1 {
494 return Ok(());
495 }
496 match b.op() {
497 Op::Value | Op::Const(_) => {}
498 Op::Reshape(arg) | Op::Layout(_, arg) | Op::Reduce(_, arg, _) | Op::Unary(_, arg) => {
499 id_cnts(arg, cnts)?
500 }
501 Op::Set { src: arg1, values: arg2, dst_layout: _ }
502 | Op::MatMul(arg1, arg2, _, _)
503 | Op::Binary(_, arg1, arg2) => {
504 id_cnts(arg1, cnts)?;
505 id_cnts(arg2, cnts)?;
506 }
507 Op::CustomIp { f: _, args, src } => {
508 for arg in args.iter() {
509 id_cnts(arg, cnts)?;
510 }
511 id_cnts(src, cnts)?
512 }
513 Op::Ssa { ssa: _, args } | Op::Custom { f: _, args } => {
514 for arg in args.iter() {
515 id_cnts(arg, cnts)?
516 }
517 }
518 }
519 Ok(())
520}