zenu_matrix/device/
cpu.rs1use serde::{Deserialize, Serialize};
2
3use crate::{memory_pool::MemPoolError, num::Num, ZENU_MATRIX_STATE};
4
5use super::{Device, DeviceBase};
6
7#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize)]
8pub struct Cpu;
9
10impl DeviceBase for Cpu {
11 fn raw_drop_ptr<T>(ptr: *mut T) {
12 unsafe { libc::free(ptr.cast::<::libc::c_void>()) }
13 }
14
15 fn mem_pool_drop_ptr(ptr: *mut u8) -> Result<(), MemPoolError> {
16 let state = &ZENU_MATRIX_STATE;
17 state.cpu.try_free(ptr)
18 }
19
20 #[expect(clippy::not_unsafe_ptr_arg_deref)]
21 fn clone_ptr<T>(ptr: *const T, len: usize) -> *mut T {
22 let mut vec = Vec::with_capacity(len);
23 for i in 0..len {
24 vec.push(unsafe { ptr.add(i).read() });
25 }
26 let ptr = vec.as_mut_ptr();
27 std::mem::forget(vec);
28 ptr
29 }
30
31 #[expect(clippy::not_unsafe_ptr_arg_deref)]
32 fn assign_item<T>(ptr: *mut T, offset: usize, value: T) {
33 unsafe {
34 ptr.add(offset).write(value);
35 }
36 }
37
38 #[expect(clippy::not_unsafe_ptr_arg_deref)]
39 fn get_item<T>(ptr: *const T, offset: usize) -> T {
40 unsafe { ptr.add(offset).read() }
41 }
42
43 fn from_vec<T>(mut vec: Vec<T>) -> *mut T {
44 let ptr = vec.as_mut_ptr().cast::<T>();
45 std::mem::forget(vec);
46 ptr
47 }
48
49 fn zeros<T: Num>(len: usize) -> *mut T {
50 use cblas::{dscal, sscal};
51 let ptr = Self::alloc(len * std::mem::size_of::<T>())
52 .unwrap()
53 .cast::<T>();
54 if T::is_f32() {
55 let slice = unsafe { std::slice::from_raw_parts_mut(ptr.cast(), 1) };
56 unsafe { sscal(i32::try_from(len).unwrap(), 0.0, slice, 1) };
57 } else {
58 let slice = unsafe { std::slice::from_raw_parts_mut(ptr.cast(), 1) };
59 unsafe { dscal(i32::try_from(len).unwrap(), 0.0, slice, 1) };
60 }
61 ptr
62 }
63
64 fn raw_alloc(num_bytes: usize) -> Result<*mut u8, String> {
65 let ptr = unsafe { libc::malloc(num_bytes) };
66 if ptr.is_null() {
67 Err("null pointer".to_string())
68 } else {
69 Ok(ptr.cast())
70 }
71 }
72
73 fn mem_pool_alloc(num_bytes: usize) -> Result<*mut u8, MemPoolError> {
74 let state = &ZENU_MATRIX_STATE;
75 state.cpu.try_alloc(num_bytes)
76 }
77}
78
// `Cpu` implements `Device` with an empty body, so it relies entirely on the
// provided (default) methods of `Device`.
// NOTE(review): presumably those defaults are built on the `DeviceBase` methods
// implemented for `Cpu` in this file — confirm against the trait in `super`.
impl Device for Cpu {}