1use device::cpu::Cpu;
2use memory_pool::MemPool;
3
4pub mod concat;
5pub mod constructor;
6pub mod device;
7pub mod dim;
8pub mod index;
9pub mod matrix;
10pub mod matrix_blas;
11pub mod matrix_iter;
12pub mod nn;
13pub mod num;
14pub mod operation;
15pub mod shape_stride;
16pub mod slice;
17
18mod impl_ops;
19mod impl_serde;
20mod matrix_format;
21mod memory_pool;
22mod with_clousers;
23
24#[cfg(feature = "nvidia")]
25use device::nvidia::Nvidia;
26
27pub(crate) struct ZenuMatrixState {
28 pub(crate) is_mem_pool_used: bool,
29 pub(crate) cpu: MemPool<Cpu>,
30 #[cfg(feature = "nvidia")]
31 pub(crate) nvidia: MemPool<Nvidia>,
32}
33
34impl Default for ZenuMatrixState {
35 fn default() -> Self {
36 let use_mem_pool = std::env::var("ZENU_USE_MEMPOOL").unwrap_or("1".to_string()) == "1";
37 ZenuMatrixState {
38 is_mem_pool_used: use_mem_pool,
39 cpu: MemPool::default(),
40 #[cfg(feature = "nvidia")]
41 nvidia: MemPool::default(),
42 }
43 }
44}
45
46pub(crate) static ZENU_MATRIX_STATE: once_cell::sync::Lazy<ZenuMatrixState> =
47 once_cell::sync::Lazy::new(ZenuMatrixState::default);