zenu_matrix/device/
mod.rs1use serde::Serialize;
2
3use crate::{
4 memory_pool::MemPoolError,
5 nn::{batch_norm::BatchNormalization, conv2d::Conv2d, dropout::Dropout, pool2d::Pool2dImpl},
6 num::Num,
7 operation::{
8 asum::Asum,
9 basic_operations::{
10 AbsOps, AcosOps, AddOps, AsinOps, AtanOps, CosOps, CoshOps, DivOps, ExpOps, LogOps,
11 MulOps, PowOws, SinOps, SinhOps, SqrtOps, SubOps, TanOps, TanhOps,
12 },
13 clip::ClipOps,
14 copy_from::CopyBlas,
15 max::MaxIdx,
16 mul::Gemm,
17 relu::ReluOps,
18 },
19 ZENU_MATRIX_STATE,
20};
21
22pub mod cpu;
23
24#[cfg(feature = "nvidia")]
25pub mod nvidia;
26
27#[expect(clippy::module_name_repetitions)]
28pub trait DeviceBase: Copy + Default + Serialize + 'static {
29 fn drop_ptr<T>(ptr: *mut T) {
30 let state = &ZENU_MATRIX_STATE;
31 if state.is_mem_pool_used {
32 let result = Self::mem_pool_drop_ptr(ptr.cast());
33 if result.is_err() {
34 Self::raw_drop_ptr(ptr);
35 }
36 } else {
37 Self::raw_drop_ptr(ptr);
38 }
39 }
40 #[expect(clippy::missing_errors_doc)]
41 fn mem_pool_drop_ptr(ptr: *mut u8) -> Result<(), MemPoolError>;
42 fn raw_drop_ptr<T>(ptr: *mut T);
43 fn clone_ptr<T>(ptr: *const T, len: usize) -> *mut T;
44 fn assign_item<T: Num>(ptr: *mut T, offset: usize, value: T);
45 fn get_item<T: Num>(ptr: *const T, offset: usize) -> T;
46 fn from_vec<T: Num>(vec: Vec<T>) -> *mut T;
47 fn zeros<T: Num>(len: usize) -> *mut T;
48 #[expect(clippy::missing_errors_doc)]
49 fn alloc(num_bytes: usize) -> Result<*mut u8, MemPoolError> {
50 let state = &ZENU_MATRIX_STATE;
51 if state.is_mem_pool_used {
52 Self::mem_pool_alloc(num_bytes)
53 } else {
54 Self::raw_alloc(num_bytes).map_err(|_| MemPoolError::DeviceMallocError)
55 }
56 }
57 #[expect(clippy::missing_errors_doc)]
58 fn mem_pool_alloc(num_bytes: usize) -> Result<*mut u8, MemPoolError>;
59 #[expect(clippy::missing_errors_doc)]
60 fn raw_alloc(num_bytes: usize) -> Result<*mut u8, String>;
61}
62
63pub trait Device:
64 DeviceBase
65 + CopyBlas
66 + AddOps
67 + SubOps
68 + MulOps
69 + DivOps
70 + Asum
71 + ClipOps
72 + SinOps
73 + CosOps
74 + TanOps
75 + AsinOps
76 + AcosOps
77 + AtanOps
78 + SinhOps
79 + CoshOps
80 + TanhOps
81 + AbsOps
82 + SqrtOps
83 + ExpOps
84 + LogOps
85 + MaxIdx
86 + ReluOps
87 + Gemm
88 + PowOws
89 + BatchNormalization
90 + Conv2d
91 + Sized
92 + Pool2dImpl
93 + Dropout
94 + Send
95 + Sync
96 + 'static
97{
98}