zenu_matrix/nn/
dropout.rs

1use crate::{
2    device::{cpu::Cpu, DeviceBase},
3    dim::{DimDyn, DimTrait},
4    matrix::{Matrix, Owned, Ref, Repr},
5    num::Num,
6};
7
8use rand::seq::SliceRandom;
9#[cfg(feature = "nvidia")]
10use zenu_cuda::cudnn::dropout::DropoutConfig;
11
12#[cfg(feature = "nvidia")]
13use crate::device::nvidia::Nvidia;
14
15#[cfg(feature = "nvidia")]
16use super::NNCache;
17
/// Per-layer dropout configuration plus the state saved by the forward pass
/// so that the backward pass can reuse the same mask.
///
/// `T` is the element type and `D` the device the layer runs on. The CPU mask
/// is always stored on [`Cpu`]; the remaining fields exist only when the crate
/// is built with the `nvidia` feature and are filled in by `gpu_init`.
pub struct DropoutState<T: Num, D: DeviceBase> {
    /// Dropout ratio: the fraction of elements that the CPU path zeroes, and
    /// the value handed to the GPU dropout configuration.
    pub rate: f32,
    /// Flat mask (ones for kept elements, zeros for dropped ones) written by
    /// the CPU forward pass; `None` until that pass has run.
    pub state: Option<Matrix<Owned<T>, DimDyn, Cpu>>,
    /// GPU dropout configuration; `None` until `gpu_init` has run.
    #[cfg(feature = "nvidia")]
    gpu_state: Option<DropoutConfig<T>>,
    /// Buffer whose pointer is handed to `DropoutConfig::set`. It is kept here
    /// so the allocation lives as long as the configuration.
    #[cfg(feature = "nvidia")]
    state_cache: Option<NNCache<D>>,
    /// Buffer passed to both the GPU forward and backward calls.
    #[cfg(feature = "nvidia")]
    reserve_space_cache: Option<NNCache<D>>,
    /// Keeps `D` in use when the `nvidia`-only fields are compiled out.
    _device: std::marker::PhantomData<D>,
}
29
30impl<T: Num, D: DeviceBase> DropoutState<T, D> {
31    #[must_use]
32    pub fn new(rate: f32) -> Self {
33        Self {
34            rate,
35            state: None,
36            #[cfg(feature = "nvidia")]
37            gpu_state: None,
38            #[cfg(feature = "nvidia")]
39            state_cache: None,
40            #[cfg(feature = "nvidia")]
41            reserve_space_cache: None,
42            _device: std::marker::PhantomData,
43        }
44    }
45
46    #[expect(clippy::missing_panics_doc)]
47    pub fn gpu_init(&mut self, shape: DimDyn) {
48        #[cfg(feature = "nvidia")]
49        {
50            let gpu_state = DropoutConfig::new(shape.slice()).unwrap();
51            let state_size = gpu_state.get_state_size();
52            let reserve_space_size = gpu_state.get_reserve_space_size();
53            let cache = NNCache::<D>::new(state_size);
54            let reserve_space_cache = NNCache::<D>::new(reserve_space_size);
55            gpu_state.set(self.rate, 0, cache.ptr.cast()).unwrap();
56            self.gpu_state = Some(gpu_state);
57            self.state_cache = Some(cache);
58            self.reserve_space_cache = Some(reserve_space_cache);
59        }
60        #[cfg(not(feature = "nvidia"))]
61        {
62            panic!("GPU support is not enabled");
63        }
64    }
65}
66
67#[expect(
68    clippy::cast_precision_loss,
69    clippy::cast_sign_loss,
70    clippy::cast_possible_truncation
71)]
72fn dropout_mask_inner<T>(input: &mut [T], dropout_ratio: f32)
73where
74    T: Copy + std::ops::Mul<T, Output = T> + std::ops::AddAssign + Num,
75{
76    assert!(
77        (0.0..=1.0).contains(&dropout_ratio),
78        "Dropout ratio must be between 0 and 1"
79    );
80
81    let num_zeros_item = (input.len() as f32 * dropout_ratio) as usize;
82    let mut rng = rand::thread_rng();
83
84    let mut indices: Vec<usize> = (0..input.len()).collect();
85    indices.shuffle(&mut rng);
86    for i in 0..num_zeros_item {
87        input[indices[i]] = T::zero();
88    }
89}
90
91#[expect(clippy::needless_pass_by_value)]
92fn dropout_mask<T>(input: Matrix<Ref<&mut T>, DimDyn, Cpu>, dropout_ratio: f32)
93where
94    T: Copy + std::ops::Mul<T, Output = T> + std::ops::AddAssign + Num,
95{
96    let input_slice = input.as_mut_slice();
97    dropout_mask_inner(input_slice, dropout_ratio);
98}
99
100pub trait Dropout: DeviceBase {
101    fn dropout<T: Num>(
102        x: &Matrix<Ref<&T>, DimDyn, Self>,
103        state: &mut DropoutState<T, Self>,
104    ) -> Matrix<Owned<T>, DimDyn, Self>;
105
106    fn dropout_grad<T: Num>(
107        dy: &Matrix<Ref<&T>, DimDyn, Self>,
108        state: &DropoutState<T, Self>,
109    ) -> Matrix<Owned<T>, DimDyn, Self>;
110}
111
112impl Dropout for Cpu {
113    fn dropout<T: Num>(
114        x: &Matrix<Ref<&T>, DimDyn, Self>,
115        state: &mut DropoutState<T, Self>,
116    ) -> Matrix<Owned<T>, DimDyn, Self> {
117        let rate = state.rate;
118        let num_elm = x.shape().num_elm();
119        let mask = {
120            let mut mask = Matrix::ones([num_elm]);
121            dropout_mask(mask.to_ref_mut(), rate);
122            mask
123        };
124        let grad_ratio = T::one() / T::from(1.0 - rate).unwrap();
125        let y = x.reshape([num_elm]) * mask.to_ref() * grad_ratio;
126        state.state = Some(mask);
127        y.reshape_no_alloc_owned(x.shape())
128    }
129
130    fn dropout_grad<T: Num>(
131        dy: &Matrix<Ref<&T>, DimDyn, Self>,
132        state: &DropoutState<T, Self>,
133    ) -> Matrix<Owned<T>, DimDyn, Self> {
134        let rate = state.rate;
135        let mask = state.state.as_ref().unwrap();
136        let dy_original_shape = dy.shape();
137        let dy = dy.reshape([dy.shape().num_elm()]);
138        let grad_ratio = T::one() / T::from(1.0 - rate).unwrap();
139        let dx = dy.to_ref() * mask.to_ref() * grad_ratio;
140        dx.reshape_no_alloc_owned(dy_original_shape)
141    }
142}
143
#[cfg(feature = "nvidia")]
impl Dropout for Nvidia {
    /// GPU forward pass.
    ///
    /// Initialises the GPU dropout configuration from the shape of `x` on first
    /// use, allocates the output, and calls `DropoutConfig::forward` with the
    /// input, output and reserve-space pointers.
    ///
    /// Panics with "Reserve space cache is not initialized" if the reserve
    /// buffer is missing after initialisation, or if the forward call fails.
    fn dropout<T: Num>(
        x: &Matrix<Ref<&T>, DimDyn, Self>,
        state: &mut DropoutState<T, Self>,
    ) -> Matrix<Owned<T>, DimDyn, Self> {
        // NOTE(review): the configuration is built once from the first input's
        // shape and never rebuilt — confirm that later calls always use the
        // same shape.
        if state.gpu_state.is_none() {
            state.gpu_init(x.shape());
        }

        let Some(reserve_space) = state.reserve_space_cache.as_ref() else {
            panic!("Reserve space cache is not initialized");
        };

        let mut y = Matrix::<Owned<T>, _, _>::alloc(x.shape().slice());
        let reserve_space_ptr = reserve_space.ptr.cast();
        let config = state.gpu_state.as_ref().unwrap();

        config
            .forward(
                x.as_ptr().cast(),
                y.to_ref_mut().as_mut_ptr().cast(),
                reserve_space_ptr,
            )
            .unwrap();
        y
    }

    /// GPU backward pass.
    ///
    /// Allocates `dx` with the shape of `dy` and calls
    /// `DropoutConfig::backward` with the reserve space that was passed to the
    /// forward call.
    ///
    /// Panics if the GPU state or reserve buffer has not been initialised (no
    /// forward pass yet) or if the backward call fails.
    fn dropout_grad<T: Num>(
        dy: &Matrix<Ref<&T>, DimDyn, Self>,
        state: &DropoutState<T, Self>,
    ) -> Matrix<Owned<T>, DimDyn, Self> {
        let config = state.gpu_state.as_ref().unwrap();
        let reserve_space = state.reserve_space_cache.as_ref().unwrap();

        let mut dx = Matrix::<Owned<T>, _, _>::alloc(dy.shape().slice());

        config
            .backward(
                dy.as_ptr().cast(),
                dx.to_ref_mut().as_mut_ptr().cast(),
                reserve_space.ptr.cast(),
            )
            .unwrap();
        dx
    }
}
204
205#[expect(clippy::missing_panics_doc)]
206pub fn dropout<R, D>(
207    x: &Matrix<R, DimDyn, D>,
208    state: &mut DropoutState<R::Item, D>,
209) -> Matrix<Owned<R::Item>, DimDyn, D>
210where
211    R: Repr,
212    D: Dropout + DeviceBase,
213{
214    assert!(
215        (x.shape().len() == 2) || (x.shape().len() == 4),
216        "Only 2D and 4D tensors are supported"
217    );
218    D::dropout(&x.to_ref(), state)
219}
220
221#[expect(clippy::missing_panics_doc)]
222#[must_use]
223pub fn dropout_grad<R, D>(
224    dy: &Matrix<R, DimDyn, D>,
225    state: &DropoutState<R::Item, D>,
226) -> Matrix<Owned<R::Item>, DimDyn, D>
227where
228    R: Repr,
229    D: Dropout + DeviceBase,
230{
231    assert!(
232        (dy.shape().len() == 2) || (dy.shape().len() == 4),
233        "Only 2D and 4D tensors are supported"
234    );
235    D::dropout_grad(&dy.to_ref(), state)
236}
237
238#[cfg(test)]
239mod dropout {
240    use zenu_test::run_mat_test;
241
242    use crate::{
243        device::{cpu::Cpu, Device},
244        matrix::Matrix,
245    };
246
247    use super::{dropout, dropout_grad, DropoutState};
248
249    #[expect(clippy::float_cmp)]
250    fn dropout_4d<D: Device>() {
251        let mut state = DropoutState::<f32, D>::new(0.8);
252        let x = crate::matrix::Matrix::from_vec(
253            vec![
254                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
255            ],
256            [1, 2, 2, 3],
257        );
258
259        let y = dropout(&x, &mut state);
260        let y_cpu = y.clone().to::<Cpu>();
261        let y_cpu_ref = y_cpu.to_ref();
262        let y_cpu_slice = y_cpu_ref.as_slice_unchecked();
263        let zero_indexed = y_cpu_slice.iter().map(|x| *x == 0.).collect::<Vec<bool>>();
264        let y_grad = Matrix::ones_like(&y);
265        let x_grad = dropout_grad(&y_grad.to_ref(), &state);
266        let x_grad_cpu = x_grad.to::<Cpu>();
267        let x_grad_cpu_ref = x_grad_cpu.to_ref();
268        let x_grad_cpu_slice = x_grad_cpu_ref.as_slice_unchecked();
269
270        for i in 0..y_cpu_slice.len() {
271            if zero_indexed[i] {
272                assert_eq!(x_grad_cpu_slice[i], 0.0);
273            } else {
274                assert_eq!(x_grad_cpu_slice[i], 1. / (1. - 0.8));
275            }
276        }
277    }
278    run_mat_test!(dropout_4d, dropout_4d_cpu, dropout_4d_gpu);
279
280    // fn dropout_2d<D: Device>() {
281    //     let mut state = DropoutState::<f32, D>::new(0.8);
282    //     let x = crate::matrix::Matrix::from_vec(
283    //         vec![
284    //             1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
285    //         ],
286    //         &[2, 2 * 3],
287    //     );
288    //
289    //     let y = dropout(&x, &mut state);
290    //     let y_cpu = y.clone().to::<Cpu>();
291    //     let y_cpu_ref = y_cpu.to_ref();
292    //     let y_cpu_slice = y_cpu_ref.as_slice();
293    //     let zero_indexed = y_cpu_slice.iter().map(|x| *x == 0.).collect::<Vec<bool>>();
294    //     let y_grad = Matrix::ones_like(&y);
295    //     let x_grad = dropout_grad(&y_grad.to_ref(), &state);
296    //     let x_grad_cpu = x_grad.to::<Cpu>();
297    //     let x_grad_cpu_ref = x_grad_cpu.to_ref();
298    //     let x_grad_cpu_slice = x_grad_cpu_ref.as_slice();
299    //
300    //     for i in 0..y_cpu_slice.len() {
301    //         if zero_indexed[i] {
302    //             assert_eq!(x_grad_cpu_slice[i], 0.0);
303    //         } else {
304    //             assert_eq!(x_grad_cpu_slice[i], 1. / (1. - 0.8));
305    //         }
306    //     }
307    // }
308    // run_mat_test!(dropout_2d, dropout_2d_cpu, dropout_2d_gpu);
309
310    // fn dropout_zeros_raito<D: Device>() {
311    //     let mut state = DropoutState::<f32, D>::new(0.75);
312    //     let x = crate::matrix::Matrix::from_vec(vec![1.0; 200], &[10, 20]);
313    //
314    //     let y = dropout(&x, &mut state);
315    //     let y_cpu = y.clone().to::<Cpu>();
316    //     let y_cpu_ref = y_cpu.to_ref();
317    //     let y_cpu_slice = y_cpu_ref.as_slice();
318    //
319    //     let mut num_zeros = 0;
320    //     for i in 0..y_cpu_slice.len() {
321    //         if y_cpu_slice[i] == 0.0 {
322    //             num_zeros += 1;
323    //         }
324    //     }
325    //
326    //     let raito = num_zeros as f32 / y_cpu_slice.len() as f32;
327    //     let expected = 0.75;
328    //     assert!(dbg!(raito - expected).abs() < 0.01);
329    // }
330    // run_mat_test!(
331    //     dropout_zeros_raito,
332    //     dropout_zeros_raito_cpu,
333    //     dropout_zeros_raito_gpu
334    // );
335}