zenu_matrix/device/
mod.rs

1use serde::Serialize;
2
3use crate::{
4    memory_pool::MemPoolError,
5    nn::{batch_norm::BatchNormalization, conv2d::Conv2d, dropout::Dropout, pool2d::Pool2dImpl},
6    num::Num,
7    operation::{
8        asum::Asum,
9        basic_operations::{
10            AbsOps, AcosOps, AddOps, AsinOps, AtanOps, CosOps, CoshOps, DivOps, ExpOps, LogOps,
11            MulOps, PowOws, SinOps, SinhOps, SqrtOps, SubOps, TanOps, TanhOps,
12        },
13        clip::ClipOps,
14        copy_from::CopyBlas,
15        max::MaxIdx,
16        mul::Gemm,
17        relu::ReluOps,
18    },
19    ZENU_MATRIX_STATE,
20};
21
22pub mod cpu;
23
24#[cfg(feature = "nvidia")]
25pub mod nvidia;
26
27#[expect(clippy::module_name_repetitions)]
28pub trait DeviceBase: Copy + Default + Serialize + 'static {
29    fn drop_ptr<T>(ptr: *mut T) {
30        let state = &ZENU_MATRIX_STATE;
31        if state.is_mem_pool_used {
32            let result = Self::mem_pool_drop_ptr(ptr.cast());
33            if result.is_err() {
34                Self::raw_drop_ptr(ptr);
35            }
36        } else {
37            Self::raw_drop_ptr(ptr);
38        }
39    }
40    #[expect(clippy::missing_errors_doc)]
41    fn mem_pool_drop_ptr(ptr: *mut u8) -> Result<(), MemPoolError>;
42    fn raw_drop_ptr<T>(ptr: *mut T);
43    fn clone_ptr<T>(ptr: *const T, len: usize) -> *mut T;
44    fn assign_item<T: Num>(ptr: *mut T, offset: usize, value: T);
45    fn get_item<T: Num>(ptr: *const T, offset: usize) -> T;
46    fn from_vec<T: Num>(vec: Vec<T>) -> *mut T;
47    fn zeros<T: Num>(len: usize) -> *mut T;
48    #[expect(clippy::missing_errors_doc)]
49    fn alloc(num_bytes: usize) -> Result<*mut u8, MemPoolError> {
50        let state = &ZENU_MATRIX_STATE;
51        if state.is_mem_pool_used {
52            Self::mem_pool_alloc(num_bytes)
53        } else {
54            Self::raw_alloc(num_bytes).map_err(|_| MemPoolError::DeviceMallocError)
55        }
56    }
57    #[expect(clippy::missing_errors_doc)]
58    fn mem_pool_alloc(num_bytes: usize) -> Result<*mut u8, MemPoolError>;
59    #[expect(clippy::missing_errors_doc)]
60    fn raw_alloc(num_bytes: usize) -> Result<*mut u8, String>;
61}
62
63pub trait Device:
64    DeviceBase
65    + CopyBlas
66    + AddOps
67    + SubOps
68    + MulOps
69    + DivOps
70    + Asum
71    + ClipOps
72    + SinOps
73    + CosOps
74    + TanOps
75    + AsinOps
76    + AcosOps
77    + AtanOps
78    + SinhOps
79    + CoshOps
80    + TanhOps
81    + AbsOps
82    + SqrtOps
83    + ExpOps
84    + LogOps
85    + MaxIdx
86    + ReluOps
87    + Gemm
88    + PowOws
89    + BatchNormalization
90    + Conv2d
91    + Sized
92    + Pool2dImpl
93    + Dropout
94    + Send
95    + Sync
96    + 'static
97{
98}