zenu_matrix/nn/mod.rs

1use crate::device::DeviceBase;
2
3pub mod batch_norm;
4pub mod col2im;
5pub mod conv2d;
6pub mod dropout;
7pub mod im2col;
8pub mod pool2d;
9
10#[cfg(feature = "nvidia")]
11pub mod rnn;
12
13#[expect(unused)]
14pub(crate) struct NNCache<D: DeviceBase> {
15    pub(crate) bytes: usize,
16    pub(crate) ptr: *mut u8,
17    _device: std::marker::PhantomData<D>,
18}
19
20impl<D: DeviceBase> NNCache<D> {
21    pub(crate) fn new(bytes: usize) -> Self {
22        let ptr = D::alloc(bytes).unwrap();
23        assert!(!ptr.is_null(), "Failed to allocate memory");
24        Self {
25            bytes,
26            ptr,
27            _device: std::marker::PhantomData,
28        }
29    }
30}
31
32impl<D: DeviceBase> Drop for NNCache<D> {
33    fn drop(&mut self) {
34        D::drop_ptr(self.ptr);
35    }
36}