// ug/cpu_runtime.rs

1use crate::{DType, Error, Layout, Result};
2use half::{bf16, f16};
3use std::path::PathBuf;
4
/// Owned host buffer used by the cpu backend (the `Slice` type of `CpuDevice`), with one
/// variant per supported [`DType`].
#[derive(Debug, Clone)]
pub enum CpuStorage {
    BF16(Vec<bf16>),
    F16(Vec<f16>),
    F32(Vec<f32>),
    I32(Vec<i32>),
    I64(Vec<i64>),
}
13
// Poor man's GADT: the variant tag records the element type of the borrowed slice, so a
// dtype-erased reference can be turned back into a typed `&[T]` once the dtype is checked.
/// Borrowed, read-only view of a host buffer, with one variant per supported [`DType`].
#[derive(Debug, Copy, Clone)]
pub enum CpuStorageRef<'a> {
    BF16(&'a [bf16]),
    F16(&'a [f16]),
    F32(&'a [f32]),
    I32(&'a [i32]),
    I64(&'a [i64]),
}
23
/// Mutably borrowed view of a host buffer, with one variant per supported [`DType`].
/// Mutable counterpart of `CpuStorageRef`.
#[derive(Debug)]
pub enum CpuStorageRefMut<'a> {
    BF16(&'a mut [bf16]),
    F16(&'a mut [f16]),
    F32(&'a mut [f32]),
    I32(&'a mut [i32]),
    I64(&'a mut [i64]),
}
32
33impl From<Vec<bf16>> for CpuStorage {
34    fn from(value: Vec<bf16>) -> Self {
35        Self::BF16(value)
36    }
37}
38
39impl From<Vec<f16>> for CpuStorage {
40    fn from(value: Vec<f16>) -> Self {
41        Self::F16(value)
42    }
43}
44
45impl From<Vec<f32>> for CpuStorage {
46    fn from(value: Vec<f32>) -> Self {
47        Self::F32(value)
48    }
49}
50
51impl From<Vec<i32>> for CpuStorage {
52    fn from(value: Vec<i32>) -> Self {
53        Self::I32(value)
54    }
55}
56
57impl From<Vec<i64>> for CpuStorage {
58    fn from(value: Vec<i64>) -> Self {
59        Self::I64(value)
60    }
61}
62
63impl CpuStorage {
64    pub fn as_mut_ptr(&mut self) -> *mut std::ffi::c_void {
65        match self {
66            Self::BF16(s) => s.as_mut_ptr() as *mut std::ffi::c_void,
67            Self::F16(s) => s.as_mut_ptr() as *mut std::ffi::c_void,
68            Self::F32(s) => s.as_mut_ptr() as *mut std::ffi::c_void,
69            Self::I32(s) => s.as_mut_ptr() as *mut std::ffi::c_void,
70            Self::I64(s) => s.as_mut_ptr() as *mut std::ffi::c_void,
71        }
72    }
73
74    pub fn as_ptr(&mut self) -> *const std::ffi::c_void {
75        match self {
76            Self::BF16(s) => s.as_ptr() as *const std::ffi::c_void,
77            Self::F16(s) => s.as_ptr() as *const std::ffi::c_void,
78            Self::F32(s) => s.as_ptr() as *const std::ffi::c_void,
79            Self::I32(s) => s.as_ptr() as *const std::ffi::c_void,
80            Self::I64(s) => s.as_ptr() as *const std::ffi::c_void,
81        }
82    }
83
84    pub fn len(&self) -> usize {
85        match self {
86            Self::BF16(s) => s.len(),
87            Self::F16(s) => s.len(),
88            Self::F32(s) => s.len(),
89            Self::I32(s) => s.len(),
90            Self::I64(s) => s.len(),
91        }
92    }
93
94    pub fn is_empty(&self) -> bool {
95        self.len() == 0
96    }
97
98    pub fn dtype(&self) -> DType {
99        match self {
100            Self::BF16(_) => DType::BF16,
101            Self::F16(_) => DType::F16,
102            Self::F32(_) => DType::F32,
103            Self::I32(_) => DType::I32,
104            Self::I64(_) => DType::I64,
105        }
106    }
107
108    pub fn as_ref(&self) -> CpuStorageRef<'_> {
109        match self {
110            Self::BF16(v) => CpuStorageRef::BF16(v.as_slice()),
111            Self::F16(v) => CpuStorageRef::F16(v.as_slice()),
112            Self::F32(v) => CpuStorageRef::F32(v.as_slice()),
113            Self::I32(v) => CpuStorageRef::I32(v.as_slice()),
114            Self::I64(v) => CpuStorageRef::I64(v.as_slice()),
115        }
116    }
117
118    pub fn as_mut_ref(&mut self) -> CpuStorageRefMut<'_> {
119        match self {
120            Self::BF16(v) => CpuStorageRefMut::BF16(v.as_mut_slice()),
121            Self::F16(v) => CpuStorageRefMut::F16(v.as_mut_slice()),
122            Self::F32(v) => CpuStorageRefMut::F32(v.as_mut_slice()),
123            Self::I32(v) => CpuStorageRefMut::I32(v.as_mut_slice()),
124            Self::I64(v) => CpuStorageRefMut::I64(v.as_mut_slice()),
125        }
126    }
127
128    pub fn data<T: crate::WithDType>(&self) -> Result<&[T]> {
129        T::from_cpu_storage(self.as_ref())
130    }
131
132    pub fn data_mut<T: crate::WithDType>(&mut self) -> Result<&mut [T]> {
133        T::from_cpu_storage_mut(self.as_mut_ref())
134    }
135}
136
137impl CpuStorageRef<'_> {
138    pub fn dtype(&self) -> DType {
139        match self {
140            Self::BF16(_) => DType::BF16,
141            Self::F16(_) => DType::F16,
142            Self::F32(_) => DType::F32,
143            Self::I32(_) => DType::I32,
144            Self::I64(_) => DType::I64,
145        }
146    }
147
148    pub fn len(&self) -> usize {
149        match self {
150            Self::BF16(s) => s.len(),
151            Self::F16(s) => s.len(),
152            Self::F32(s) => s.len(),
153            Self::I32(s) => s.len(),
154            Self::I64(s) => s.len(),
155        }
156    }
157
158    pub fn is_empty(&self) -> bool {
159        self.len() == 0
160    }
161}
162
163impl<'a> From<&'a [bf16]> for CpuStorageRef<'a> {
164    fn from(value: &'a [bf16]) -> Self {
165        Self::BF16(value)
166    }
167}
168
169impl<'a> From<&'a [f16]> for CpuStorageRef<'a> {
170    fn from(value: &'a [f16]) -> Self {
171        Self::F16(value)
172    }
173}
174
175impl<'a> From<&'a [f32]> for CpuStorageRef<'a> {
176    fn from(value: &'a [f32]) -> Self {
177        Self::F32(value)
178    }
179}
180
181impl<'a> From<&'a [i32]> for CpuStorageRef<'a> {
182    fn from(value: &'a [i32]) -> Self {
183        Self::I32(value)
184    }
185}
186
187impl<'a> From<&'a [i64]> for CpuStorageRef<'a> {
188    fn from(value: &'a [i64]) -> Self {
189        Self::I64(value)
190    }
191}
192
193impl CpuStorageRefMut<'_> {
194    pub fn dtype(&self) -> DType {
195        match self {
196            Self::BF16(_) => DType::BF16,
197            Self::F16(_) => DType::F16,
198            Self::F32(_) => DType::F32,
199            Self::I32(_) => DType::I32,
200            Self::I64(_) => DType::I64,
201        }
202    }
203    pub fn len(&self) -> usize {
204        match self {
205            Self::BF16(s) => s.len(),
206            Self::F16(s) => s.len(),
207            Self::F32(s) => s.len(),
208            Self::I32(s) => s.len(),
209            Self::I64(s) => s.len(),
210        }
211    }
212
213    pub fn is_empty(&self) -> bool {
214        self.len() == 0
215    }
216}
217
218impl<'a> From<&'a mut [bf16]> for CpuStorageRefMut<'a> {
219    fn from(value: &'a mut [bf16]) -> Self {
220        Self::BF16(value)
221    }
222}
223
224impl<'a> From<&'a mut [f16]> for CpuStorageRefMut<'a> {
225    fn from(value: &'a mut [f16]) -> Self {
226        Self::F16(value)
227    }
228}
229
230impl<'a> From<&'a mut [f32]> for CpuStorageRefMut<'a> {
231    fn from(value: &'a mut [f32]) -> Self {
232        Self::F32(value)
233    }
234}
235
236impl<'a> From<&'a mut [i32]> for CpuStorageRefMut<'a> {
237    fn from(value: &'a mut [i32]) -> Self {
238        Self::I32(value)
239    }
240}
241
242impl<'a> From<&'a mut [i64]> for CpuStorageRefMut<'a> {
243    fn from(value: &'a mut [i64]) -> Self {
244        Self::I64(value)
245    }
246}
247
/// The host device. Kernels are turned into C code, compiled with gcc into a shared library
/// and called in-process; buffers are plain `CpuStorage` vectors.
#[derive(Clone, Copy, Debug)]
pub struct CpuDevice;
250
251impl crate::Device for CpuDevice {
252    type Slice = CpuStorage;
253    type Func = Func;
254
255    unsafe fn allocate_uninit(&self, dtype: DType, len: usize) -> Result<Self::Slice> {
256        let slice = match dtype {
257            DType::BF16 => CpuStorage::BF16(vec![bf16::ZERO; len]),
258            DType::F16 => CpuStorage::F16(vec![f16::ZERO; len]),
259            DType::F32 => CpuStorage::F32(vec![0f32; len]),
260            DType::I32 => CpuStorage::I32(vec![0i32; len]),
261            DType::I64 => CpuStorage::I64(vec![0i64; len]),
262        };
263        Ok(slice)
264    }
265
266    fn synchronize(&self) -> Result<()> {
267        Ok(())
268    }
269
270    fn use_grid() -> bool {
271        false
272    }
273
274    fn compile(&self, kernel: &crate::lang::ssa::Kernel, name: Option<&str>) -> Result<Self::Func> {
275        let mut c_code = Vec::with_capacity(8192);
276        // Compilation for the cpu runtime uses temporary files, so we use the pid to ensure that
277        // there is no name collision.
278        let pid = std::process::id();
279        let kernel_id = KernelId::new().as_usize();
280        let func_name = match name {
281            Some(name) => format!("ugc_{name}_{pid}_{kernel_id}"),
282            None => format!("ugc_{pid}_{kernel_id}"),
283        };
284        crate::cpu_code_gen::gen(&mut c_code, &func_name, kernel)?;
285        self.compile_c(&c_code, func_name)
286    }
287
288    fn run(&self, f: &Self::Func, args: &mut [&mut Self::Slice]) -> Result<()> {
289        use libloading::Symbol as S;
290        use std::ffi::c_void;
291
292        let func_name = f.func_name.as_bytes();
293        // TODO: For the calls below to be safe, we should store the kernel signature in Func
294        // and check that args matches it.
295        match args {
296            [] => {
297                let symbol: S<extern "C" fn()> = unsafe { f.lib.get(func_name)? };
298                symbol()
299            }
300            [a1] => {
301                let symbol: S<extern "C" fn(*mut c_void)> = unsafe { f.lib.get(func_name)? };
302                symbol(a1.as_mut_ptr())
303            }
304            [a1, a2] => {
305                let symbol: S<extern "C" fn(*mut c_void, *mut c_void)> =
306                    unsafe { f.lib.get(func_name)? };
307                symbol(a1.as_mut_ptr(), a2.as_mut_ptr())
308            }
309            [a1, a2, a3] => {
310                let symbol: S<extern "C" fn(*mut c_void, *mut c_void, *mut c_void)> =
311                    unsafe { f.lib.get(func_name)? };
312                symbol(a1.as_mut_ptr(), a2.as_mut_ptr(), a3.as_mut_ptr())
313            }
314            [a1, a2, a3, a4] => {
315                let symbol: S<extern "C" fn(*mut c_void, *mut c_void, *mut c_void, *mut c_void)> =
316                    unsafe { f.lib.get(func_name)? };
317                symbol(a1.as_mut_ptr(), a2.as_mut_ptr(), a3.as_mut_ptr(), a4.as_mut_ptr())
318            }
319            [a1, a2, a3, a4, a5] => {
320                let symbol: S<
321                    extern "C" fn(*mut c_void, *mut c_void, *mut c_void, *mut c_void, *mut c_void),
322                > = unsafe { f.lib.get(func_name)? };
323                symbol(
324                    a1.as_mut_ptr(),
325                    a2.as_mut_ptr(),
326                    a3.as_mut_ptr(),
327                    a4.as_mut_ptr(),
328                    a5.as_mut_ptr(),
329                )
330            }
331            _ => crate::bail!("unsupported number of args for kernel {}", args.len()),
332        }
333        Ok(())
334    }
335
336    fn matmul(
337        &self,
338        dst: &mut Self::Slice,
339        lhs: &Self::Slice,
340        rhs: &Self::Slice,
341        bmnk: (usize, usize, usize, usize),
342        lhs_l: &Layout,
343        rhs_l: &Layout,
344    ) -> Result<()> {
345        use CpuStorage::{F16, F32};
346        let mm = MatMul(bmnk);
347        let (dst_dt, lhs_dt, rhs_dt) = (dst.dtype(), lhs.dtype(), rhs.dtype());
348        match (dst, lhs, rhs) {
349            (F16(dst), F16(lhs), F16(rhs)) => mm.gemm(dst, lhs, lhs_l, rhs, rhs_l)?,
350            (F32(dst), F32(lhs), F32(rhs)) => mm.gemm(dst, lhs, lhs_l, rhs, rhs_l)?,
351            _ => {
352                crate::bail!(
353                    "incorrect dtypes for matmul, dst: {dst_dt:?}, lhs: {lhs_dt:?}, rhs: {rhs_dt:?}"
354                )
355            }
356        }
357        Ok(())
358    }
359}
360
361impl crate::Slice for CpuStorage {
362    type Device = CpuDevice;
363
364    fn len(&self) -> usize {
365        CpuStorage::len(self)
366    }
367
368    fn dtype(&self) -> crate::DType {
369        CpuStorage::dtype(self)
370    }
371
372    fn device(&self) -> &Self::Device {
373        &CpuDevice
374    }
375
376    fn copy_host_to_device<DT: crate::WithDType>(&mut self, src: &[DT]) -> Result<()> {
377        use CpuStorage as S;
378        use CpuStorageRef as C;
379        let dtype = self.dtype();
380        if src.len() != self.len() {
381            crate::bail!("dtoh len mismatch, dst {}, len {}", self.len(), src.len())
382        }
383        match (self, DT::to_cpu_storage(src)) {
384            (S::BF16(dst), C::BF16(src)) => dst.copy_from_slice(src),
385            (S::F16(dst), C::F16(src)) => dst.copy_from_slice(src),
386            (S::F32(dst), C::F32(src)) => dst.copy_from_slice(src),
387            (S::I32(dst), C::I32(src)) => dst.copy_from_slice(src),
388            (S::I64(dst), C::I64(src)) => dst.copy_from_slice(src),
389            (_, _) => {
390                crate::bail!("htod dtype mismatch, dst {dtype:?}, src {:?}", DT::DTYPE)
391            }
392        }
393        Ok(())
394    }
395
396    fn copy_device_to_host<DT: crate::WithDType>(&self, dst: &mut [DT]) -> Result<()> {
397        use CpuStorage as S;
398        use CpuStorageRefMut as C;
399        let dtype = self.dtype();
400        if dst.len() != self.len() {
401            crate::bail!("dtoh len mismatch, dst {}, len {}", dst.len(), self.len())
402        }
403        match (self, DT::to_cpu_storage_mut(dst)) {
404            (S::BF16(src), C::BF16(dst)) => dst.copy_from_slice(src),
405            (S::F16(src), C::F16(dst)) => dst.copy_from_slice(src),
406            (S::F32(src), C::F32(dst)) => dst.copy_from_slice(src),
407            (S::I32(src), C::I32(dst)) => dst.copy_from_slice(src),
408            (S::I64(src), C::I64(dst)) => dst.copy_from_slice(src),
409            (_, _) => crate::bail!("dtoh dtype mismatch, dst {:?}, src {dtype:?}", DT::DTYPE),
410        }
411        Ok(())
412    }
413}
414
/// Identifier handed out to each compiled kernel, unique within the current process.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct KernelId(usize);

impl KernelId {
    /// Draws the next id from a global counter that starts at 1.
    ///
    /// Relaxed ordering is enough: only the atomicity of the increment matters for
    /// uniqueness, not how it is ordered relative to other memory operations.
    pub(crate) fn new() -> Self {
        // See https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805
        use std::sync::atomic::{AtomicUsize, Ordering};
        static NEXT_ID: AtomicUsize = AtomicUsize::new(1);
        let id = NEXT_ID.fetch_add(1, Ordering::Relaxed);
        KernelId(id)
    }

    /// The raw numeric value of this id.
    pub fn as_usize(&self) -> usize {
        let KernelId(raw) = *self;
        raw
    }
}
430
431pub struct Func {
432    func_name: String,
433    lib: libloading::Library,
434}
435
436impl Func {
437    pub fn name(&self) -> &str {
438        self.func_name.as_str()
439    }
440
441    #[allow(clippy::missing_safety_doc)]
442    pub unsafe fn run0(&self) -> Result<()> {
443        let func_name = self.func_name.as_bytes();
444        let symbol: libloading::Symbol<unsafe extern "C" fn()> = self.lib.get(func_name)?;
445        symbol();
446        Ok(())
447    }
448
449    #[allow(clippy::missing_safety_doc)]
450    pub unsafe fn run3<T>(&self, v1: &mut [T], v2: &mut [T], v3: &mut [T]) -> Result<()> {
451        use std::ffi::c_void;
452
453        let func_name = self.func_name.as_bytes();
454        let symbol: libloading::Symbol<
455            unsafe extern "C" fn(*mut c_void, *mut c_void, *mut c_void),
456        > = self.lib.get(func_name)?;
457        symbol(
458            v1.as_mut_ptr() as *mut c_void,
459            v2.as_mut_ptr() as *mut c_void,
460            v3.as_mut_ptr() as *mut c_void,
461        );
462        Ok(())
463    }
464}
465
466impl crate::CpuDevice {
467    pub fn compile_c(&self, c_code: &[u8], func_name: String) -> Result<Func> {
468        fn compile_inner(
469            c_code: &[u8],
470            func_name: String,
471            tmp_c: &PathBuf,
472            tmp_so: &PathBuf,
473        ) -> Result<Func> {
474            std::fs::write(tmp_c, c_code)?;
475            // TODO: add some environment variable or other ways to set some flags.
476            let output = std::process::Command::new("gcc")
477                .arg(tmp_c)
478                .args([
479                    "-shared",
480                    "-lm",
481                    "-O3",
482                    "-march=native",
483                    "-ffast-math",
484                    "-fomit-frame-pointer",
485                    "-o",
486                ])
487                .arg(tmp_so)
488                .output()?;
489
490            if !output.status.success() {
491                crate::bail!(
492                    "compilation failed\nstdout:\n{}\nstderr:{}",
493                    String::from_utf8_lossy(&output.stdout),
494                    String::from_utf8_lossy(&output.stderr)
495                )
496            }
497            let lib = unsafe { libloading::Library::new(tmp_so)? };
498            Ok(Func { func_name, lib })
499        }
500
501        let tmp_dir = std::env::temp_dir();
502        let tmp_c = tmp_dir.join(format!("{func_name}.c"));
503        let tmp_so = tmp_dir.join(format!("{func_name}.so"));
504        let result = compile_inner(c_code, func_name, &tmp_c, &tmp_so);
505        // Ensure that the temporary files are cleaned up, even on failures.
506        if !crate::utils::KEEP_TMP.with(|b| *b) {
507            let _ = std::fs::remove_file(tmp_c);
508            let _ = std::fs::remove_file(tmp_so);
509        }
510        result
511    }
512}
513
514pub struct MatMul((usize, usize, usize, usize));
515
516impl MatMul {
517    fn striding_error(&self, lhs_l: &Layout, rhs_l: &Layout, msg: &'static str) -> Error {
518        Error::MatMulUnexpectedStriding(Box::new(crate::error::MatMulUnexpectedStriding {
519            lhs_l: lhs_l.clone(),
520            rhs_l: rhs_l.clone(),
521            bmnk: self.0,
522            msg,
523        }))
524        .bt()
525    }
526
527    fn ab_skip(&self, lhs_l: &Layout, rhs_l: &Layout) -> Result<(usize, usize)> {
528        let lhs_stride = lhs_l.strides();
529        let rhs_stride = rhs_l.strides();
530        let rank = lhs_stride.len();
531        let (_b, m, n, k) = self.0;
532        let a_skip: usize = match lhs_stride[..rank - 2] {
533            [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
534            [_, stride] if lhs_l.dims()[0] == 1 => stride,
535            [stride, _] if lhs_l.dims()[1] == 1 => stride,
536            [stride] => stride,
537            [] => m * k,
538            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
539        };
540        let b_skip: usize = match rhs_stride[..rank - 2] {
541            [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
542            [_, stride] if rhs_l.dims()[0] == 1 => stride,
543            [stride, _] if rhs_l.dims()[1] == 1 => stride,
544            [stride] => stride,
545            [] => n * k,
546            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
547        };
548        Ok((a_skip, b_skip))
549    }
550
551    pub fn gemm<T: crate::WithDType>(
552        &self,
553        dst: &mut [T],
554        lhs: &[T],
555        lhs_l: &Layout,
556        rhs: &[T],
557        rhs_l: &Layout,
558    ) -> Result<()> {
559        use gemm::{gemm, Parallelism};
560
561        match T::DTYPE {
562            DType::F16 | DType::F32 => {}
563            _ => crate::bail!("unsupported dtype for gemm"),
564        }
565
566        let (b, m, n, k) = self.0;
567        let lhs = &lhs[lhs_l.offset()..];
568        let rhs = &rhs[rhs_l.offset()..];
569
570        let lhs_strides = lhs_l.strides();
571        let rhs_strides = rhs_l.strides();
572        let rank = lhs_strides.len();
573        let lhs_cs = lhs_strides[rank - 1];
574        let lhs_rs = lhs_strides[rank - 2];
575
576        let rhs_cs = rhs_strides[rank - 1];
577        let rhs_rs = rhs_strides[rank - 2];
578
579        let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?;
580        let c_skip: usize = m * n;
581
582        let dst_shape: crate::Shape = (m, n).into();
583        let dst_strides = dst_shape.stride_contiguous();
584        let dst_rs = dst_strides[0];
585        let dst_cs = dst_strides[1];
586
587        let num_threads = crate::utils::get_num_threads();
588        let parallelism =
589            if num_threads > 1 { Parallelism::Rayon(num_threads) } else { Parallelism::None };
590        for step in 0..b {
591            let lhs_p = &lhs[step * a_skip..];
592            let rhs_p = &rhs[step * b_skip..];
593            let dst_p = &mut dst[step * c_skip..];
594            unsafe {
595                gemm(
596                    /* m: usize = */ m,
597                    /* n: usize = */ n,
598                    /* k: usize = */ k,
599                    /* dst: *mut T = */ dst_p.as_mut_ptr(),
600                    /* dst_cs: isize = */ dst_cs as isize,
601                    /* dst_rs: isize = */ dst_rs as isize,
602                    /* read_dst: bool = */ false,
603                    /* lhs: *const T = */ lhs_p.as_ptr(),
604                    /* lhs_cs: isize = */ lhs_cs as isize,
605                    /* lhs_rs: isize = */ lhs_rs as isize,
606                    /* rhs: *const T = */ rhs_p.as_ptr(),
607                    /* rhs_cs: isize = */ rhs_cs as isize,
608                    /* rhs_rs: isize = */ rhs_rs as isize,
609                    /* alpha: T = */ T::zero(),
610                    /* beta: T = */ T::one(),
611                    /* conj_dst: bool = */ false,
612                    /* conj_lhs: bool = */ false,
613                    /* conj_rhs: bool = */ false,
614                    parallelism,
615                )
616            }
617        }
618        Ok(())
619    }
620}