zenu_matrix/
lib.rs

1use device::cpu::Cpu;
2use memory_pool::MemPool;
3
4pub mod concat;
5pub mod constructor;
6pub mod device;
7pub mod dim;
8pub mod index;
9pub mod matrix;
10pub mod matrix_blas;
11pub mod matrix_iter;
12pub mod nn;
13pub mod num;
14pub mod operation;
15pub mod shape_stride;
16pub mod slice;
17
18mod impl_ops;
19mod impl_serde;
20mod matrix_format;
21mod memory_pool;
22mod with_clousers;
23
24#[cfg(feature = "nvidia")]
25use device::nvidia::Nvidia;
26
27pub(crate) struct ZenuMatrixState {
28    pub(crate) is_mem_pool_used: bool,
29    pub(crate) cpu: MemPool<Cpu>,
30    #[cfg(feature = "nvidia")]
31    pub(crate) nvidia: MemPool<Nvidia>,
32}
33
34impl Default for ZenuMatrixState {
35    fn default() -> Self {
36        let use_mem_pool = std::env::var("ZENU_USE_MEMPOOL").unwrap_or("1".to_string()) == "1";
37        ZenuMatrixState {
38            is_mem_pool_used: use_mem_pool,
39            cpu: MemPool::default(),
40            #[cfg(feature = "nvidia")]
41            nvidia: MemPool::default(),
42        }
43    }
44}
45
46pub(crate) static ZENU_MATRIX_STATE: once_cell::sync::Lazy<ZenuMatrixState> =
47    once_cell::sync::Lazy::new(ZenuMatrixState::default);