1use crate::{
2 device::{cpu::Cpu, DeviceBase},
3 dim::{DimDyn, DimTrait},
4 matrix::{Matrix, Owned, Ref, Repr},
5 num::Num,
6};
7
8use rand::seq::SliceRandom;
9#[cfg(feature = "nvidia")]
10use zenu_cuda::cudnn::dropout::DropoutConfig;
11
12#[cfg(feature = "nvidia")]
13use crate::device::nvidia::Nvidia;
14
15#[cfg(feature = "nvidia")]
16use super::NNCache;
17
/// Per-layer dropout state: the drop `rate` plus the mask / GPU resources
/// produced by the most recent forward pass.
pub struct DropoutState<T: Num, D: DeviceBase> {
    /// Fraction of elements to drop; expected to be in `[0, 1]`.
    pub rate: f32,
    /// CPU path: the 0/1 mask from the last forward pass, stored flat
    /// (one element per input element).
    pub state: Option<Matrix<Owned<T>, DimDyn, Cpu>>,
    /// cuDNN dropout descriptor, created lazily by `gpu_init`.
    #[cfg(feature = "nvidia")]
    gpu_state: Option<DropoutConfig<T>>,
    /// Device buffer backing the cuDNN dropout RNG state
    /// (passed to `DropoutConfig::set`).
    #[cfg(feature = "nvidia")]
    state_cache: Option<NNCache<D>>,
    /// Device buffer for cuDNN's reserve space, shared between the
    /// forward and backward calls.
    #[cfg(feature = "nvidia")]
    reserve_space_cache: Option<NNCache<D>>,
    /// Ties the state to a device type without storing a device value.
    _device: std::marker::PhantomData<D>,
}
29
30impl<T: Num, D: DeviceBase> DropoutState<T, D> {
31 #[must_use]
32 pub fn new(rate: f32) -> Self {
33 Self {
34 rate,
35 state: None,
36 #[cfg(feature = "nvidia")]
37 gpu_state: None,
38 #[cfg(feature = "nvidia")]
39 state_cache: None,
40 #[cfg(feature = "nvidia")]
41 reserve_space_cache: None,
42 _device: std::marker::PhantomData,
43 }
44 }
45
46 #[expect(clippy::missing_panics_doc)]
47 pub fn gpu_init(&mut self, shape: DimDyn) {
48 #[cfg(feature = "nvidia")]
49 {
50 let gpu_state = DropoutConfig::new(shape.slice()).unwrap();
51 let state_size = gpu_state.get_state_size();
52 let reserve_space_size = gpu_state.get_reserve_space_size();
53 let cache = NNCache::<D>::new(state_size);
54 let reserve_space_cache = NNCache::<D>::new(reserve_space_size);
55 gpu_state.set(self.rate, 0, cache.ptr.cast()).unwrap();
56 self.gpu_state = Some(gpu_state);
57 self.state_cache = Some(cache);
58 self.reserve_space_cache = Some(reserve_space_cache);
59 }
60 #[cfg(not(feature = "nvidia"))]
61 {
62 panic!("GPU support is not enabled");
63 }
64 }
65}
66
67#[expect(
68 clippy::cast_precision_loss,
69 clippy::cast_sign_loss,
70 clippy::cast_possible_truncation
71)]
72fn dropout_mask_inner<T>(input: &mut [T], dropout_ratio: f32)
73where
74 T: Copy + std::ops::Mul<T, Output = T> + std::ops::AddAssign + Num,
75{
76 assert!(
77 (0.0..=1.0).contains(&dropout_ratio),
78 "Dropout ratio must be between 0 and 1"
79 );
80
81 let num_zeros_item = (input.len() as f32 * dropout_ratio) as usize;
82 let mut rng = rand::thread_rng();
83
84 let mut indices: Vec<usize> = (0..input.len()).collect();
85 indices.shuffle(&mut rng);
86 for i in 0..num_zeros_item {
87 input[indices[i]] = T::zero();
88 }
89}
90
91#[expect(clippy::needless_pass_by_value)]
92fn dropout_mask<T>(input: Matrix<Ref<&mut T>, DimDyn, Cpu>, dropout_ratio: f32)
93where
94 T: Copy + std::ops::Mul<T, Output = T> + std::ops::AddAssign + Num,
95{
96 let input_slice = input.as_mut_slice();
97 dropout_mask_inner(input_slice, dropout_ratio);
98}
99
/// Device-specific dropout forward/backward kernels.
pub trait Dropout: DeviceBase {
    /// Forward pass: returns `x` with a random fraction (`state.rate`) of
    /// elements zeroed, and updates `state` with whatever the backward pass
    /// needs (the mask on CPU, the cuDNN reserve space on GPU).
    fn dropout<T: Num>(
        x: &Matrix<Ref<&T>, DimDyn, Self>,
        state: &mut DropoutState<T, Self>,
    ) -> Matrix<Owned<T>, DimDyn, Self>;

    /// Backward pass: propagates `dy` through the mask recorded by the most
    /// recent forward call on `state`.
    fn dropout_grad<T: Num>(
        dy: &Matrix<Ref<&T>, DimDyn, Self>,
        state: &DropoutState<T, Self>,
    ) -> Matrix<Owned<T>, DimDyn, Self>;
}
111
112impl Dropout for Cpu {
113 fn dropout<T: Num>(
114 x: &Matrix<Ref<&T>, DimDyn, Self>,
115 state: &mut DropoutState<T, Self>,
116 ) -> Matrix<Owned<T>, DimDyn, Self> {
117 let rate = state.rate;
118 let num_elm = x.shape().num_elm();
119 let mask = {
120 let mut mask = Matrix::ones([num_elm]);
121 dropout_mask(mask.to_ref_mut(), rate);
122 mask
123 };
124 let grad_ratio = T::one() / T::from(1.0 - rate).unwrap();
125 let y = x.reshape([num_elm]) * mask.to_ref() * grad_ratio;
126 state.state = Some(mask);
127 y.reshape_no_alloc_owned(x.shape())
128 }
129
130 fn dropout_grad<T: Num>(
131 dy: &Matrix<Ref<&T>, DimDyn, Self>,
132 state: &DropoutState<T, Self>,
133 ) -> Matrix<Owned<T>, DimDyn, Self> {
134 let rate = state.rate;
135 let mask = state.state.as_ref().unwrap();
136 let dy_original_shape = dy.shape();
137 let dy = dy.reshape([dy.shape().num_elm()]);
138 let grad_ratio = T::one() / T::from(1.0 - rate).unwrap();
139 let dx = dy.to_ref() * mask.to_ref() * grad_ratio;
140 dx.reshape_no_alloc_owned(dy_original_shape)
141 }
142}
143
#[cfg(feature = "nvidia")]
impl Dropout for Nvidia {
    /// GPU forward: runs cuDNN dropout, writing survivors into a freshly
    /// allocated output and the mask bookkeeping into the reserve space.
    fn dropout<T: Num>(
        x: &Matrix<Ref<&T>, DimDyn, Self>,
        state: &mut DropoutState<T, Self>,
    ) -> Matrix<Owned<T>, DimDyn, Self> {
        // Lazily create the cuDNN descriptor and work buffers on first use.
        if state.gpu_state.is_none() {
            state.gpu_init(x.shape());
        }

        assert!(
            state.reserve_space_cache.is_some(),
            "Reserve space cache is not initialized"
        );

        let mut y = Matrix::<Owned<T>, _, _>::alloc(x.shape().slice());
        let space_cache = state.reserve_space_cache.as_ref().unwrap();
        let space_cache_ptr = space_cache.ptr.cast();

        {
            // Scope the mutable view so `y` can be returned afterwards.
            let y_mut_ref = y.to_ref_mut();
            state
                .gpu_state
                .as_ref()
                .unwrap()
                .forward(
                    x.as_ptr().cast(),
                    y_mut_ref.as_mut_ptr().cast(),
                    space_cache_ptr,
                )
                .unwrap();
        }
        y
    }

    /// GPU backward: replays the mask from the reserve space written by the
    /// matching forward call.
    fn dropout_grad<T: Num>(
        dy: &Matrix<Ref<&T>, DimDyn, Self>,
        state: &DropoutState<T, Self>,
    ) -> Matrix<Owned<T>, DimDyn, Self> {
        // Forward must have initialized the descriptor and reserve space.
        let gpu_state = state.gpu_state.as_ref().unwrap();
        let space_cache = state.reserve_space_cache.as_ref().unwrap();

        let mut dx = Matrix::<Owned<T>, _, _>::alloc(dy.shape().slice());

        {
            let dx_mut_ref = dx.to_ref_mut();
            gpu_state
                .backward(
                    dy.as_ptr().cast(),
                    dx_mut_ref.as_mut_ptr().cast(),
                    space_cache.ptr.cast(),
                )
                .unwrap();
        }
        dx
    }
}
204
205#[expect(clippy::missing_panics_doc)]
206pub fn dropout<R, D>(
207 x: &Matrix<R, DimDyn, D>,
208 state: &mut DropoutState<R::Item, D>,
209) -> Matrix<Owned<R::Item>, DimDyn, D>
210where
211 R: Repr,
212 D: Dropout + DeviceBase,
213{
214 assert!(
215 (x.shape().len() == 2) || (x.shape().len() == 4),
216 "Only 2D and 4D tensors are supported"
217 );
218 D::dropout(&x.to_ref(), state)
219}
220
221#[expect(clippy::missing_panics_doc)]
222#[must_use]
223pub fn dropout_grad<R, D>(
224 dy: &Matrix<R, DimDyn, D>,
225 state: &DropoutState<R::Item, D>,
226) -> Matrix<Owned<R::Item>, DimDyn, D>
227where
228 R: Repr,
229 D: Dropout + DeviceBase,
230{
231 assert!(
232 (dy.shape().len() == 2) || (dy.shape().len() == 4),
233 "Only 2D and 4D tensors are supported"
234 );
235 D::dropout_grad(&dy.to_ref(), state)
236}
237
#[cfg(test)]
mod dropout {
    use zenu_test::run_mat_test;

    use crate::{
        device::{cpu::Cpu, Device},
        matrix::Matrix,
    };

    use super::{dropout, dropout_grad, DropoutState};

    /// Checks that the gradient is 0 exactly where the forward output was
    /// dropped, and 1 / (1 - rate) everywhere else.
    #[expect(clippy::float_cmp)]
    fn dropout_4d<D: Device>() {
        let mut state = DropoutState::<f32, D>::new(0.8);
        let x = crate::matrix::Matrix::from_vec(
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
            [1, 2, 2, 3],
        );

        let y = dropout(&x, &mut state);
        // Compare on the CPU regardless of which device ran the forward.
        let y_cpu = y.clone().to::<Cpu>();
        let y_cpu_ref = y_cpu.to_ref();
        let y_cpu_slice = y_cpu_ref.as_slice_unchecked();

        let y_grad = Matrix::ones_like(&y);
        let x_grad = dropout_grad(&y_grad.to_ref(), &state);
        let x_grad_cpu = x_grad.to::<Cpu>();
        let x_grad_cpu_ref = x_grad_cpu.to_ref();
        let x_grad_cpu_slice = x_grad_cpu_ref.as_slice_unchecked();

        // Gradient positions must line up element-wise with the output:
        // zero where dropped, 1 / (1 - rate) where kept.
        for (&y_val, &g) in y_cpu_slice.iter().zip(x_grad_cpu_slice) {
            if y_val == 0. {
                assert_eq!(g, 0.0);
            } else {
                assert_eq!(g, 1. / (1. - 0.8));
            }
        }
    }
    run_mat_test!(dropout_4d, dropout_4d_cpu, dropout_4d_gpu);
}