// ug/lib.rs

1pub mod block;
2pub mod cache;
3pub mod common_tests;
4pub mod r#const;
5pub mod cpu_code_gen;
6pub mod cpu_runtime;
7pub mod display;
8pub mod dtype;
9pub mod error;
10pub mod interpreter;
11pub mod lang;
12pub mod layout;
13pub mod lazy_buffer;
14pub mod lower;
15pub mod lower_op;
16pub mod safetensors;
17pub mod samples;
18pub mod schedule;
19pub mod utils;
20
21pub use cpu_runtime::{CpuDevice, CpuStorage, CpuStorageRef, CpuStorageRefMut};
22pub use dtype::{DType, WithDType};
23pub use error::{Error, Result};
24pub use layout::{Dim, Layout, Shape, D};
25pub use lazy_buffer::LazyBuffer;
26pub use r#const::Const;
27pub use schedule::{Schedule, ScheduleItem};
28
29pub trait Slice: std::fmt::Debug {
30    type Device: Device<Slice = Self>;
31
32    fn device(&self) -> &Self::Device;
33    fn dtype(&self) -> DType;
34    fn len(&self) -> usize;
35    fn copy_host_to_device<DT: WithDType>(&mut self, src: &[DT]) -> Result<()>;
36    fn copy_device_to_host<DT: WithDType>(&self, dst: &mut [DT]) -> Result<()>;
37
38    fn is_empty(&self) -> bool {
39        self.len() == 0
40    }
41
42    fn to_vec<DT: WithDType>(&self) -> Result<Vec<DT>> {
43        let mut host = vec![DT::zero(); self.len()];
44        self.copy_device_to_host(&mut host)?;
45        Ok(host)
46    }
47}
48
49pub trait Device: Clone + std::fmt::Debug {
50    type Slice: Slice<Device = Self>;
51    type Func;
52
53    #[allow(clippy::missing_safety_doc)]
54    unsafe fn allocate_uninit(&self, dtype: DType, len: usize) -> Result<Self::Slice>;
55    fn synchronize(&self) -> Result<()>;
56    fn compile(&self, kernel: &crate::lang::ssa::Kernel, name: Option<&str>) -> Result<Self::Func>;
57    // TODO: currently const parameters are hardcoded in the kernel and new code is generated for
58    // these when necessary. Maybe we should have a more generic arg type that could handle
59    // `Const` scalars.
60    fn run(&self, f: &Self::Func, args: &mut [&mut Self::Slice]) -> Result<()>;
61
62    fn matmul(
63        &self,
64        _dst: &mut Self::Slice,
65        _lhs: &Self::Slice,
66        _rhs: &Self::Slice,
67        _bmnk: (usize, usize, usize, usize),
68        _lhs_l: &Layout,
69        _rhs_l: &Layout,
70    ) -> Result<()>;
71
72    fn use_grid() -> bool;
73}