// scirs2_neural/layers/dropout.rs

use crate::error::{NeuralError, Result};
use crate::layers::Layer;
use scirs2_core::ndarray::{Array, IxDyn, ScalarOperand};
use scirs2_core::numeric::Float;
use scirs2_core::random::{Rng, RngCore, SeedableRng};
use std::fmt::Debug;
use std::marker::PhantomData;
use std::sync::{Arc, RwLock};

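/// Dropout layer with inverted scaling.
///
/// During training, each input element is zeroed with probability `p` and the
/// surviving elements are scaled by `1 / (1 - p)`, so the expected activation is
/// unchanged and no rescaling is needed at inference time. In inference mode the
/// layer is the identity.
///
/// # Examples
///
/// A minimal usage sketch (not from the original source); it assumes `Dropout` and
/// the `Layer` trait are re-exported from `scirs2_neural::layers`, and it reuses
/// `SmallRng::from_seed` exactly as this file does in its `Clone` implementation:
///
/// ```ignore
/// use scirs2_core::ndarray::{Array, IxDyn};
/// use scirs2_core::random::rngs::SmallRng;
/// use scirs2_core::random::SeedableRng;
/// use scirs2_neural::layers::{Dropout, Layer};
///
/// let mut rng = SmallRng::from_seed([42; 32]);
/// let dropout = Dropout::<f64>::new(0.5, &mut rng).unwrap();
/// let x = Array::<f64, IxDyn>::from_elem(IxDyn(&[2, 4]), 1.0);
/// let y = dropout.forward(&x).unwrap(); // each element is either 0.0 or 2.0 during training
/// ```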
pub struct Dropout<F: Float + Debug + Send + Sync> {
    /// Probability of zeroing each element; must lie in `[0, 1)`.
    p: F,
    /// Random number generator used to sample the dropout mask.
    rng: Arc<RwLock<Box<dyn RngCore + Send + Sync>>>,
    /// Whether the layer is in training mode (dropout active) or inference mode (identity).
    training: bool,
    /// Cached input from the most recent forward pass.
    input_cache: Arc<RwLock<Option<Array<F, IxDyn>>>>,
    /// Cached dropout mask from the most recent forward pass, used by `backward`.
    mask_cache: Arc<RwLock<Option<Array<F, IxDyn>>>>,
    /// Marker tying the struct to the element type `F`.
    _phantom: PhantomData<F>,
}

// Manual `Debug` implementation: the boxed RNG does not implement `Debug`, so it
// is shown as a placeholder and the caches are omitted.
impl<F: Float + Debug + Send + Sync> std::fmt::Debug for Dropout<F> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Dropout")
            .field("p", &self.p)
            .field("rng", &"<dyn RngCore>")
            .field("training", &self.training)
            .finish()
    }
}

impl<F: Float + Debug + Send + Sync> Clone for Dropout<F> {
    fn clone(&self) -> Self {
        // The boxed RNG cannot be cloned, so the clone gets a fresh RNG with a
        // fixed seed and starts with empty caches.
        let rng = scirs2_core::random::rngs::SmallRng::from_seed([42; 32]);
        Self {
            p: self.p,
            rng: Arc::new(RwLock::new(Box::new(rng))),
            training: self.training,
            input_cache: Arc::new(RwLock::new(None)),
            mask_cache: Arc::new(RwLock::new(None)),
            _phantom: PhantomData,
        }
    }
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Dropout<F> {
    /// Creates a new dropout layer.
    ///
    /// `p` is the probability of zeroing each element and must lie in `[0, 1)`.
    /// The provided RNG is cloned and used to sample dropout masks. Returns an
    /// error if `p` is out of range or cannot be converted to `F`.
    pub fn new<R: Rng + 'static + Clone + Send + Sync>(p: f64, rng: &mut R) -> Result<Self> {
        if !(0.0..1.0).contains(&p) {
            return Err(NeuralError::InvalidArchitecture(
                "Dropout probability must be in [0, 1)".to_string(),
            ));
        }

        let p = F::from(p).ok_or_else(|| {
            NeuralError::InvalidArchitecture(
                "Failed to convert dropout probability to type F".to_string(),
            )
        })?;

        Ok(Self {
            p,
            rng: Arc::new(RwLock::new(Box::new(rng.clone()))),
            training: true,
            input_cache: Arc::new(RwLock::new(None)),
            mask_cache: Arc::new(RwLock::new(None)),
            _phantom: PhantomData,
        })
    }

    /// Switches between training mode (dropout active) and inference mode (identity).
    pub fn set_training(&mut self, training: bool) {
        self.training = training;
    }

    /// Returns the dropout probability as an `f64`.
    pub fn p(&self) -> f64 {
        self.p.to_f64().unwrap_or(0.0)
    }

    /// Returns `true` if the layer is in training mode.
    pub fn is_training(&self) -> bool {
        self.training
    }
}

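// Inverted dropout, as implemented below: during training a binary mask `m` with
// P(m_i = 0) = p is sampled, and
//
//     forward:   y     = (x * m) / (1 - p)          (elementwise)
//     backward:  dL/dx = (dL/dy * m) / (1 - p)
//
// so E[y] = x and no rescaling is needed at inference time, where the layer is
// simply the identity.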
impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Layer<F> for Dropout<F> {
    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
        // Cache the input for potential use in the backward pass.
        if let Ok(mut cache) = self.input_cache.write() {
            *cache = Some(input.clone());
        } else {
            return Err(NeuralError::InferenceError(
                "Failed to acquire write lock on input cache".to_string(),
            ));
        }

        // In inference mode, or when p == 0, dropout is the identity.
        if !self.training || self.p == F::zero() {
            return Ok(input.clone());
        }

        // Build a binary mask: each element is dropped (set to zero) with probability p.
        let mut mask = Array::<F, IxDyn>::from_elem(input.dim(), F::one());
        let one = F::one();
        let zero = F::zero();

        {
            let mut rng_guard = match self.rng.write() {
                Ok(guard) => guard,
                Err(_) => {
                    return Err(NeuralError::InferenceError(
                        "Failed to acquire write lock on RNG".to_string(),
                    ))
                }
            };

            for elem in mask.iter_mut() {
                if F::from((**rng_guard).random::<f64>()).unwrap() < self.p {
                    *elem = zero;
                }
            }
        }

        // Inverted dropout: scale the surviving activations by 1 / (1 - p) so the
        // expected output matches the input.
        let scale = one / (one - self.p);

        // Cache the mask for the backward pass.
        if let Ok(mut cache) = self.mask_cache.write() {
            *cache = Some(mask.clone());
        } else {
            return Err(NeuralError::InferenceError(
                "Failed to acquire write lock on mask cache".to_string(),
            ));
        }

        let output = input * &mask * scale;
        Ok(output)
    }

    fn backward(
        &self,
        _input: &Array<F, IxDyn>,
        grad_output: &Array<F, IxDyn>,
    ) -> Result<Array<F, IxDyn>> {
        // In inference mode dropout is the identity, so the gradient passes through.
        if !self.training {
            return Ok(grad_output.clone());
        }

        // Retrieve the mask that was sampled during the forward pass.
        let mask = {
            let cache = match self.mask_cache.read() {
                Ok(cache) => cache,
                Err(_) => {
                    return Err(NeuralError::InferenceError(
                        "Failed to acquire read lock on mask cache".to_string(),
                    ))
                }
            };

            match cache.as_ref() {
                Some(mask) => mask.clone(),
                None => {
                    return Err(NeuralError::InferenceError(
                        "No cached mask for backward pass".to_string(),
                    ))
                }
            }
        };

        // Apply the same mask and scaling as the forward pass.
        let one = F::one();
        let scale = one / (one - self.p);

        let grad_input = grad_output * &mask * scale;
        Ok(grad_input)
    }

    fn update(&mut self, _learningrate: F) -> Result<()> {
        // Dropout has no learnable parameters, so there is nothing to update.
        Ok(())
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }

    fn set_training(&mut self, training: bool) {
        self.training = training;
    }

    fn is_training(&self) -> bool {
        self.training
    }

    fn layer_type(&self) -> &str {
        "Dropout"
    }

    fn parameter_count(&self) -> usize {
        0
    }

    fn layer_description(&self) -> String {
        format!("type:Dropout, p:{:.3}", self.p())
    }
}

// SAFETY: every field of `Dropout` is itself `Send + Sync` (the RNG is boxed as
// `dyn RngCore + Send + Sync` behind `Arc<RwLock<..>>`), so these impls are sound.
unsafe impl<F: Float + Debug + Send + Sync> Send for Dropout<F> {}
unsafe impl<F: Float + Debug + Send + Sync> Sync for Dropout<F> {}
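
#[cfg(test)]
mod tests {
    // A minimal sanity-check sketch, not part of the original file. It assumes
    // `scirs2_core::random::rngs::SmallRng` implements `Rng + Clone + Send + Sync`
    // and exposes `from_seed` as used in the `Clone` impl above.
    use super::*;
    use scirs2_core::random::rngs::SmallRng;

    #[test]
    fn inference_mode_is_identity() {
        let mut rng = SmallRng::from_seed([42; 32]);
        let mut dropout = Dropout::<f64>::new(0.5, &mut rng).unwrap();
        dropout.set_training(false);

        let x = Array::<f64, IxDyn>::from_elem(IxDyn(&[2, 3]), 1.0);
        let y = dropout.forward(&x).unwrap();
        assert_eq!(y, x);
    }

    #[test]
    fn training_mode_zeroes_or_scales() {
        let mut rng = SmallRng::from_seed([42; 32]);
        let dropout = Dropout::<f64>::new(0.5, &mut rng).unwrap();

        let x = Array::<f64, IxDyn>::from_elem(IxDyn(&[4, 4]), 1.0);
        let y = dropout.forward(&x).unwrap();

        // With inverted dropout and p = 0.5, every element is either dropped (0.0)
        // or scaled by 1 / (1 - p) = 2.0.
        for &v in y.iter() {
            assert!(v == 0.0 || (v - 2.0).abs() < 1e-12);
        }
    }
}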