1use crate::{DType, Error, Layout, Result};
2use half::{bf16, f16};
3use std::path::PathBuf;
4
/// Owned CPU buffer tagged with its element type: one `Vec` variant per supported dtype.
#[derive(Debug, Clone)]
pub enum CpuStorage {
    /// Elements stored as `bf16`.
    BF16(Vec<bf16>),
    /// Elements stored as `f16`.
    F16(Vec<f16>),
    /// Elements stored as `f32`.
    F32(Vec<f32>),
    /// Elements stored as `i32`.
    I32(Vec<i32>),
    /// Elements stored as `i64`.
    I64(Vec<i64>),
}
13
/// Shared (read-only) borrow of typed CPU data, with one slice variant per supported dtype.
///
/// The variants mirror those of [`CpuStorage`]; [`CpuStorage::as_ref`] produces this view.
#[derive(Debug, Copy, Clone)]
pub enum CpuStorageRef<'a> {
    /// Borrowed `bf16` elements.
    BF16(&'a [bf16]),
    /// Borrowed `f16` elements.
    F16(&'a [f16]),
    /// Borrowed `f32` elements.
    F32(&'a [f32]),
    /// Borrowed `i32` elements.
    I32(&'a [i32]),
    /// Borrowed `i64` elements.
    I64(&'a [i64]),
}
23
/// Exclusive (mutable) borrow of typed CPU data, with one slice variant per supported dtype.
///
/// The variants mirror those of [`CpuStorage`]; [`CpuStorage::as_mut_ref`] produces this view.
#[derive(Debug)]
pub enum CpuStorageRefMut<'a> {
    /// Mutably borrowed `bf16` elements.
    BF16(&'a mut [bf16]),
    /// Mutably borrowed `f16` elements.
    F16(&'a mut [f16]),
    /// Mutably borrowed `f32` elements.
    F32(&'a mut [f32]),
    /// Mutably borrowed `i32` elements.
    I32(&'a mut [i32]),
    /// Mutably borrowed `i64` elements.
    I64(&'a mut [i64]),
}
32
33impl From<Vec<bf16>> for CpuStorage {
34 fn from(value: Vec<bf16>) -> Self {
35 Self::BF16(value)
36 }
37}
38
39impl From<Vec<f16>> for CpuStorage {
40 fn from(value: Vec<f16>) -> Self {
41 Self::F16(value)
42 }
43}
44
45impl From<Vec<f32>> for CpuStorage {
46 fn from(value: Vec<f32>) -> Self {
47 Self::F32(value)
48 }
49}
50
51impl From<Vec<i32>> for CpuStorage {
52 fn from(value: Vec<i32>) -> Self {
53 Self::I32(value)
54 }
55}
56
57impl From<Vec<i64>> for CpuStorage {
58 fn from(value: Vec<i64>) -> Self {
59 Self::I64(value)
60 }
61}
62
63impl CpuStorage {
64 pub fn as_mut_ptr(&mut self) -> *mut std::ffi::c_void {
65 match self {
66 Self::BF16(s) => s.as_mut_ptr() as *mut std::ffi::c_void,
67 Self::F16(s) => s.as_mut_ptr() as *mut std::ffi::c_void,
68 Self::F32(s) => s.as_mut_ptr() as *mut std::ffi::c_void,
69 Self::I32(s) => s.as_mut_ptr() as *mut std::ffi::c_void,
70 Self::I64(s) => s.as_mut_ptr() as *mut std::ffi::c_void,
71 }
72 }
73
74 pub fn as_ptr(&mut self) -> *const std::ffi::c_void {
75 match self {
76 Self::BF16(s) => s.as_ptr() as *const std::ffi::c_void,
77 Self::F16(s) => s.as_ptr() as *const std::ffi::c_void,
78 Self::F32(s) => s.as_ptr() as *const std::ffi::c_void,
79 Self::I32(s) => s.as_ptr() as *const std::ffi::c_void,
80 Self::I64(s) => s.as_ptr() as *const std::ffi::c_void,
81 }
82 }
83
84 pub fn len(&self) -> usize {
85 match self {
86 Self::BF16(s) => s.len(),
87 Self::F16(s) => s.len(),
88 Self::F32(s) => s.len(),
89 Self::I32(s) => s.len(),
90 Self::I64(s) => s.len(),
91 }
92 }
93
94 pub fn is_empty(&self) -> bool {
95 self.len() == 0
96 }
97
98 pub fn dtype(&self) -> DType {
99 match self {
100 Self::BF16(_) => DType::BF16,
101 Self::F16(_) => DType::F16,
102 Self::F32(_) => DType::F32,
103 Self::I32(_) => DType::I32,
104 Self::I64(_) => DType::I64,
105 }
106 }
107
108 pub fn as_ref(&self) -> CpuStorageRef<'_> {
109 match self {
110 Self::BF16(v) => CpuStorageRef::BF16(v.as_slice()),
111 Self::F16(v) => CpuStorageRef::F16(v.as_slice()),
112 Self::F32(v) => CpuStorageRef::F32(v.as_slice()),
113 Self::I32(v) => CpuStorageRef::I32(v.as_slice()),
114 Self::I64(v) => CpuStorageRef::I64(v.as_slice()),
115 }
116 }
117
118 pub fn as_mut_ref(&mut self) -> CpuStorageRefMut<'_> {
119 match self {
120 Self::BF16(v) => CpuStorageRefMut::BF16(v.as_mut_slice()),
121 Self::F16(v) => CpuStorageRefMut::F16(v.as_mut_slice()),
122 Self::F32(v) => CpuStorageRefMut::F32(v.as_mut_slice()),
123 Self::I32(v) => CpuStorageRefMut::I32(v.as_mut_slice()),
124 Self::I64(v) => CpuStorageRefMut::I64(v.as_mut_slice()),
125 }
126 }
127
128 pub fn data<T: crate::WithDType>(&self) -> Result<&[T]> {
129 T::from_cpu_storage(self.as_ref())
130 }
131
132 pub fn data_mut<T: crate::WithDType>(&mut self) -> Result<&mut [T]> {
133 T::from_cpu_storage_mut(self.as_mut_ref())
134 }
135}
136
137impl CpuStorageRef<'_> {
138 pub fn dtype(&self) -> DType {
139 match self {
140 Self::BF16(_) => DType::BF16,
141 Self::F16(_) => DType::F16,
142 Self::F32(_) => DType::F32,
143 Self::I32(_) => DType::I32,
144 Self::I64(_) => DType::I64,
145 }
146 }
147
148 pub fn len(&self) -> usize {
149 match self {
150 Self::BF16(s) => s.len(),
151 Self::F16(s) => s.len(),
152 Self::F32(s) => s.len(),
153 Self::I32(s) => s.len(),
154 Self::I64(s) => s.len(),
155 }
156 }
157
158 pub fn is_empty(&self) -> bool {
159 self.len() == 0
160 }
161}
162
163impl<'a> From<&'a [bf16]> for CpuStorageRef<'a> {
164 fn from(value: &'a [bf16]) -> Self {
165 Self::BF16(value)
166 }
167}
168
169impl<'a> From<&'a [f16]> for CpuStorageRef<'a> {
170 fn from(value: &'a [f16]) -> Self {
171 Self::F16(value)
172 }
173}
174
175impl<'a> From<&'a [f32]> for CpuStorageRef<'a> {
176 fn from(value: &'a [f32]) -> Self {
177 Self::F32(value)
178 }
179}
180
181impl<'a> From<&'a [i32]> for CpuStorageRef<'a> {
182 fn from(value: &'a [i32]) -> Self {
183 Self::I32(value)
184 }
185}
186
187impl<'a> From<&'a [i64]> for CpuStorageRef<'a> {
188 fn from(value: &'a [i64]) -> Self {
189 Self::I64(value)
190 }
191}
192
193impl CpuStorageRefMut<'_> {
194 pub fn dtype(&self) -> DType {
195 match self {
196 Self::BF16(_) => DType::BF16,
197 Self::F16(_) => DType::F16,
198 Self::F32(_) => DType::F32,
199 Self::I32(_) => DType::I32,
200 Self::I64(_) => DType::I64,
201 }
202 }
203 pub fn len(&self) -> usize {
204 match self {
205 Self::BF16(s) => s.len(),
206 Self::F16(s) => s.len(),
207 Self::F32(s) => s.len(),
208 Self::I32(s) => s.len(),
209 Self::I64(s) => s.len(),
210 }
211 }
212
213 pub fn is_empty(&self) -> bool {
214 self.len() == 0
215 }
216}
217
218impl<'a> From<&'a mut [bf16]> for CpuStorageRefMut<'a> {
219 fn from(value: &'a mut [bf16]) -> Self {
220 Self::BF16(value)
221 }
222}
223
224impl<'a> From<&'a mut [f16]> for CpuStorageRefMut<'a> {
225 fn from(value: &'a mut [f16]) -> Self {
226 Self::F16(value)
227 }
228}
229
230impl<'a> From<&'a mut [f32]> for CpuStorageRefMut<'a> {
231 fn from(value: &'a mut [f32]) -> Self {
232 Self::F32(value)
233 }
234}
235
236impl<'a> From<&'a mut [i32]> for CpuStorageRefMut<'a> {
237 fn from(value: &'a mut [i32]) -> Self {
238 Self::I32(value)
239 }
240}
241
242impl<'a> From<&'a mut [i64]> for CpuStorageRefMut<'a> {
243 fn from(value: &'a mut [i64]) -> Self {
244 Self::I64(value)
245 }
246}
247
/// Field-less handle for the CPU backend.
///
/// It carries no state; the `crate::Device` implementation in this file uses [`CpuStorage`]
/// as its slice type and [`Func`] as its compiled-kernel type.
#[derive(Clone, Copy, Debug)]
pub struct CpuDevice;
250
251impl crate::Device for CpuDevice {
252 type Slice = CpuStorage;
253 type Func = Func;
254
255 unsafe fn allocate_uninit(&self, dtype: DType, len: usize) -> Result<Self::Slice> {
256 let slice = match dtype {
257 DType::BF16 => CpuStorage::BF16(vec![bf16::ZERO; len]),
258 DType::F16 => CpuStorage::F16(vec![f16::ZERO; len]),
259 DType::F32 => CpuStorage::F32(vec![0f32; len]),
260 DType::I32 => CpuStorage::I32(vec![0i32; len]),
261 DType::I64 => CpuStorage::I64(vec![0i64; len]),
262 };
263 Ok(slice)
264 }
265
266 fn synchronize(&self) -> Result<()> {
267 Ok(())
268 }
269
270 fn use_grid() -> bool {
271 false
272 }
273
274 fn compile(&self, kernel: &crate::lang::ssa::Kernel, name: Option<&str>) -> Result<Self::Func> {
275 let mut c_code = Vec::with_capacity(8192);
276 let pid = std::process::id();
279 let kernel_id = KernelId::new().as_usize();
280 let func_name = match name {
281 Some(name) => format!("ugc_{name}_{pid}_{kernel_id}"),
282 None => format!("ugc_{pid}_{kernel_id}"),
283 };
284 crate::cpu_code_gen::gen(&mut c_code, &func_name, kernel)?;
285 self.compile_c(&c_code, func_name)
286 }
287
288 fn run(&self, f: &Self::Func, args: &mut [&mut Self::Slice]) -> Result<()> {
289 use libloading::Symbol as S;
290 use std::ffi::c_void;
291
292 let func_name = f.func_name.as_bytes();
293 match args {
296 [] => {
297 let symbol: S<extern "C" fn()> = unsafe { f.lib.get(func_name)? };
298 symbol()
299 }
300 [a1] => {
301 let symbol: S<extern "C" fn(*mut c_void)> = unsafe { f.lib.get(func_name)? };
302 symbol(a1.as_mut_ptr())
303 }
304 [a1, a2] => {
305 let symbol: S<extern "C" fn(*mut c_void, *mut c_void)> =
306 unsafe { f.lib.get(func_name)? };
307 symbol(a1.as_mut_ptr(), a2.as_mut_ptr())
308 }
309 [a1, a2, a3] => {
310 let symbol: S<extern "C" fn(*mut c_void, *mut c_void, *mut c_void)> =
311 unsafe { f.lib.get(func_name)? };
312 symbol(a1.as_mut_ptr(), a2.as_mut_ptr(), a3.as_mut_ptr())
313 }
314 [a1, a2, a3, a4] => {
315 let symbol: S<extern "C" fn(*mut c_void, *mut c_void, *mut c_void, *mut c_void)> =
316 unsafe { f.lib.get(func_name)? };
317 symbol(a1.as_mut_ptr(), a2.as_mut_ptr(), a3.as_mut_ptr(), a4.as_mut_ptr())
318 }
319 [a1, a2, a3, a4, a5] => {
320 let symbol: S<
321 extern "C" fn(*mut c_void, *mut c_void, *mut c_void, *mut c_void, *mut c_void),
322 > = unsafe { f.lib.get(func_name)? };
323 symbol(
324 a1.as_mut_ptr(),
325 a2.as_mut_ptr(),
326 a3.as_mut_ptr(),
327 a4.as_mut_ptr(),
328 a5.as_mut_ptr(),
329 )
330 }
331 _ => crate::bail!("unsupported number of args for kernel {}", args.len()),
332 }
333 Ok(())
334 }
335
336 fn matmul(
337 &self,
338 dst: &mut Self::Slice,
339 lhs: &Self::Slice,
340 rhs: &Self::Slice,
341 bmnk: (usize, usize, usize, usize),
342 lhs_l: &Layout,
343 rhs_l: &Layout,
344 ) -> Result<()> {
345 use CpuStorage::{F16, F32};
346 let mm = MatMul(bmnk);
347 let (dst_dt, lhs_dt, rhs_dt) = (dst.dtype(), lhs.dtype(), rhs.dtype());
348 match (dst, lhs, rhs) {
349 (F16(dst), F16(lhs), F16(rhs)) => mm.gemm(dst, lhs, lhs_l, rhs, rhs_l)?,
350 (F32(dst), F32(lhs), F32(rhs)) => mm.gemm(dst, lhs, lhs_l, rhs, rhs_l)?,
351 _ => {
352 crate::bail!(
353 "incorrect dtypes for matmul, dst: {dst_dt:?}, lhs: {lhs_dt:?}, rhs: {rhs_dt:?}"
354 )
355 }
356 }
357 Ok(())
358 }
359}
360
361impl crate::Slice for CpuStorage {
362 type Device = CpuDevice;
363
364 fn len(&self) -> usize {
365 CpuStorage::len(self)
366 }
367
368 fn dtype(&self) -> crate::DType {
369 CpuStorage::dtype(self)
370 }
371
372 fn device(&self) -> &Self::Device {
373 &CpuDevice
374 }
375
376 fn copy_host_to_device<DT: crate::WithDType>(&mut self, src: &[DT]) -> Result<()> {
377 use CpuStorage as S;
378 use CpuStorageRef as C;
379 let dtype = self.dtype();
380 if src.len() != self.len() {
381 crate::bail!("dtoh len mismatch, dst {}, len {}", self.len(), src.len())
382 }
383 match (self, DT::to_cpu_storage(src)) {
384 (S::BF16(dst), C::BF16(src)) => dst.copy_from_slice(src),
385 (S::F16(dst), C::F16(src)) => dst.copy_from_slice(src),
386 (S::F32(dst), C::F32(src)) => dst.copy_from_slice(src),
387 (S::I32(dst), C::I32(src)) => dst.copy_from_slice(src),
388 (S::I64(dst), C::I64(src)) => dst.copy_from_slice(src),
389 (_, _) => {
390 crate::bail!("htod dtype mismatch, dst {dtype:?}, src {:?}", DT::DTYPE)
391 }
392 }
393 Ok(())
394 }
395
396 fn copy_device_to_host<DT: crate::WithDType>(&self, dst: &mut [DT]) -> Result<()> {
397 use CpuStorage as S;
398 use CpuStorageRefMut as C;
399 let dtype = self.dtype();
400 if dst.len() != self.len() {
401 crate::bail!("dtoh len mismatch, dst {}, len {}", dst.len(), self.len())
402 }
403 match (self, DT::to_cpu_storage_mut(dst)) {
404 (S::BF16(src), C::BF16(dst)) => dst.copy_from_slice(src),
405 (S::F16(src), C::F16(dst)) => dst.copy_from_slice(src),
406 (S::F32(src), C::F32(dst)) => dst.copy_from_slice(src),
407 (S::I32(src), C::I32(dst)) => dst.copy_from_slice(src),
408 (S::I64(src), C::I64(dst)) => dst.copy_from_slice(src),
409 (_, _) => crate::bail!("dtoh dtype mismatch, dst {:?}, src {dtype:?}", DT::DTYPE),
410 }
411 Ok(())
412 }
413}
414
/// Identifier handed out to each compiled kernel.
///
/// Values come from a single static counter that starts at 1 and is bumped on every call to
/// `KernelId::new`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct KernelId(usize);

impl KernelId {
    /// Hands out the next identifier from the shared counter.
    pub(crate) fn new() -> Self {
        use std::sync::atomic::{AtomicUsize, Ordering};
        static NEXT_ID: AtomicUsize = AtomicUsize::new(1);
        let id = NEXT_ID.fetch_add(1, Ordering::Relaxed);
        KernelId(id)
    }

    /// Returns the raw numeric value of the identifier.
    pub fn as_usize(&self) -> usize {
        let KernelId(id) = *self;
        id
    }
}
430
/// A compiled kernel: the name of its exported symbol plus the shared library that holds it.
pub struct Func {
    /// Symbol name that gets looked up in `lib` each time the kernel is run.
    func_name: String,
    /// Library handle in which `func_name` is looked up.
    lib: libloading::Library,
}
435
436impl Func {
437 pub fn name(&self) -> &str {
438 self.func_name.as_str()
439 }
440
441 #[allow(clippy::missing_safety_doc)]
442 pub unsafe fn run0(&self) -> Result<()> {
443 let func_name = self.func_name.as_bytes();
444 let symbol: libloading::Symbol<unsafe extern "C" fn()> = self.lib.get(func_name)?;
445 symbol();
446 Ok(())
447 }
448
449 #[allow(clippy::missing_safety_doc)]
450 pub unsafe fn run3<T>(&self, v1: &mut [T], v2: &mut [T], v3: &mut [T]) -> Result<()> {
451 use std::ffi::c_void;
452
453 let func_name = self.func_name.as_bytes();
454 let symbol: libloading::Symbol<
455 unsafe extern "C" fn(*mut c_void, *mut c_void, *mut c_void),
456 > = self.lib.get(func_name)?;
457 symbol(
458 v1.as_mut_ptr() as *mut c_void,
459 v2.as_mut_ptr() as *mut c_void,
460 v3.as_mut_ptr() as *mut c_void,
461 );
462 Ok(())
463 }
464}
465
466impl crate::CpuDevice {
467 pub fn compile_c(&self, c_code: &[u8], func_name: String) -> Result<Func> {
468 fn compile_inner(
469 c_code: &[u8],
470 func_name: String,
471 tmp_c: &PathBuf,
472 tmp_so: &PathBuf,
473 ) -> Result<Func> {
474 std::fs::write(tmp_c, c_code)?;
475 let output = std::process::Command::new("gcc")
477 .arg(tmp_c)
478 .args([
479 "-shared",
480 "-lm",
481 "-O3",
482 "-march=native",
483 "-ffast-math",
484 "-fomit-frame-pointer",
485 "-o",
486 ])
487 .arg(tmp_so)
488 .output()?;
489
490 if !output.status.success() {
491 crate::bail!(
492 "compilation failed\nstdout:\n{}\nstderr:{}",
493 String::from_utf8_lossy(&output.stdout),
494 String::from_utf8_lossy(&output.stderr)
495 )
496 }
497 let lib = unsafe { libloading::Library::new(tmp_so)? };
498 Ok(Func { func_name, lib })
499 }
500
501 let tmp_dir = std::env::temp_dir();
502 let tmp_c = tmp_dir.join(format!("{func_name}.c"));
503 let tmp_so = tmp_dir.join(format!("{func_name}.so"));
504 let result = compile_inner(c_code, func_name, &tmp_c, &tmp_so);
505 if !crate::utils::KEEP_TMP.with(|b| *b) {
507 let _ = std::fs::remove_file(tmp_c);
508 let _ = std::fs::remove_file(tmp_so);
509 }
510 result
511 }
512}
513
/// Problem sizes of a batched matrix multiplication, stored as `(b, m, n, k)`: `b` steps, each
/// writing an `m * n` block of the destination.
pub struct MatMul((usize, usize, usize, usize));
515
516impl MatMul {
517 fn striding_error(&self, lhs_l: &Layout, rhs_l: &Layout, msg: &'static str) -> Error {
518 Error::MatMulUnexpectedStriding(Box::new(crate::error::MatMulUnexpectedStriding {
519 lhs_l: lhs_l.clone(),
520 rhs_l: rhs_l.clone(),
521 bmnk: self.0,
522 msg,
523 }))
524 .bt()
525 }
526
527 fn ab_skip(&self, lhs_l: &Layout, rhs_l: &Layout) -> Result<(usize, usize)> {
528 let lhs_stride = lhs_l.strides();
529 let rhs_stride = rhs_l.strides();
530 let rank = lhs_stride.len();
531 let (_b, m, n, k) = self.0;
532 let a_skip: usize = match lhs_stride[..rank - 2] {
533 [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
534 [_, stride] if lhs_l.dims()[0] == 1 => stride,
535 [stride, _] if lhs_l.dims()[1] == 1 => stride,
536 [stride] => stride,
537 [] => m * k,
538 _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
539 };
540 let b_skip: usize = match rhs_stride[..rank - 2] {
541 [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
542 [_, stride] if rhs_l.dims()[0] == 1 => stride,
543 [stride, _] if rhs_l.dims()[1] == 1 => stride,
544 [stride] => stride,
545 [] => n * k,
546 _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
547 };
548 Ok((a_skip, b_skip))
549 }
550
551 pub fn gemm<T: crate::WithDType>(
552 &self,
553 dst: &mut [T],
554 lhs: &[T],
555 lhs_l: &Layout,
556 rhs: &[T],
557 rhs_l: &Layout,
558 ) -> Result<()> {
559 use gemm::{gemm, Parallelism};
560
561 match T::DTYPE {
562 DType::F16 | DType::F32 => {}
563 _ => crate::bail!("unsupported dtype for gemm"),
564 }
565
566 let (b, m, n, k) = self.0;
567 let lhs = &lhs[lhs_l.offset()..];
568 let rhs = &rhs[rhs_l.offset()..];
569
570 let lhs_strides = lhs_l.strides();
571 let rhs_strides = rhs_l.strides();
572 let rank = lhs_strides.len();
573 let lhs_cs = lhs_strides[rank - 1];
574 let lhs_rs = lhs_strides[rank - 2];
575
576 let rhs_cs = rhs_strides[rank - 1];
577 let rhs_rs = rhs_strides[rank - 2];
578
579 let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?;
580 let c_skip: usize = m * n;
581
582 let dst_shape: crate::Shape = (m, n).into();
583 let dst_strides = dst_shape.stride_contiguous();
584 let dst_rs = dst_strides[0];
585 let dst_cs = dst_strides[1];
586
587 let num_threads = crate::utils::get_num_threads();
588 let parallelism =
589 if num_threads > 1 { Parallelism::Rayon(num_threads) } else { Parallelism::None };
590 for step in 0..b {
591 let lhs_p = &lhs[step * a_skip..];
592 let rhs_p = &rhs[step * b_skip..];
593 let dst_p = &mut dst[step * c_skip..];
594 unsafe {
595 gemm(
596 m,
597 n,
598 k,
599 dst_p.as_mut_ptr(),
600 dst_cs as isize,
601 dst_rs as isize,
602 false,
603 lhs_p.as_ptr(),
604 lhs_cs as isize,
605 lhs_rs as isize,
606 rhs_p.as_ptr(),
607 rhs_cs as isize,
608 rhs_rs as isize,
609 T::zero(),
610 T::one(),
611 false,
612 false,
613 false,
614 parallelism,
615 )
616 }
617 }
618 Ok(())
619 }
620}