zenu_matrix/device/
cpu.rs

1use serde::{Deserialize, Serialize};
2
3use crate::{memory_pool::MemPoolError, num::Num, ZENU_MATRIX_STATE};
4
5use super::{Device, DeviceBase};
6
7#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize)]
8pub struct Cpu;
9
10impl DeviceBase for Cpu {
11    fn raw_drop_ptr<T>(ptr: *mut T) {
12        unsafe { libc::free(ptr.cast::<::libc::c_void>()) }
13    }
14
15    fn mem_pool_drop_ptr(ptr: *mut u8) -> Result<(), MemPoolError> {
16        let state = &ZENU_MATRIX_STATE;
17        state.cpu.try_free(ptr)
18    }
19
20    #[expect(clippy::not_unsafe_ptr_arg_deref)]
21    fn clone_ptr<T>(ptr: *const T, len: usize) -> *mut T {
22        let mut vec = Vec::with_capacity(len);
23        for i in 0..len {
24            vec.push(unsafe { ptr.add(i).read() });
25        }
26        let ptr = vec.as_mut_ptr();
27        std::mem::forget(vec);
28        ptr
29    }
30
31    #[expect(clippy::not_unsafe_ptr_arg_deref)]
32    fn assign_item<T>(ptr: *mut T, offset: usize, value: T) {
33        unsafe {
34            ptr.add(offset).write(value);
35        }
36    }
37
38    #[expect(clippy::not_unsafe_ptr_arg_deref)]
39    fn get_item<T>(ptr: *const T, offset: usize) -> T {
40        unsafe { ptr.add(offset).read() }
41    }
42
43    fn from_vec<T>(mut vec: Vec<T>) -> *mut T {
44        let ptr = vec.as_mut_ptr().cast::<T>();
45        std::mem::forget(vec);
46        ptr
47    }
48
49    fn zeros<T: Num>(len: usize) -> *mut T {
50        use cblas::{dscal, sscal};
51        let ptr = Self::alloc(len * std::mem::size_of::<T>())
52            .unwrap()
53            .cast::<T>();
54        if T::is_f32() {
55            let slice = unsafe { std::slice::from_raw_parts_mut(ptr.cast(), 1) };
56            unsafe { sscal(i32::try_from(len).unwrap(), 0.0, slice, 1) };
57        } else {
58            let slice = unsafe { std::slice::from_raw_parts_mut(ptr.cast(), 1) };
59            unsafe { dscal(i32::try_from(len).unwrap(), 0.0, slice, 1) };
60        }
61        ptr
62    }
63
64    fn raw_alloc(num_bytes: usize) -> Result<*mut u8, String> {
65        let ptr = unsafe { libc::malloc(num_bytes) };
66        if ptr.is_null() {
67            Err("null pointer".to_string())
68        } else {
69            Ok(ptr.cast())
70        }
71    }
72
73    fn mem_pool_alloc(num_bytes: usize) -> Result<*mut u8, MemPoolError> {
74        let state = &ZENU_MATRIX_STATE;
75        state.cpu.try_alloc(num_bytes)
76    }
77}
78
79impl Device for Cpu {}