scirs2_neural/layers/dropout.rs

//! Dropout layer implementation
//!
//! This module provides an implementation of dropout regularization
//! for neural networks as described in "Dropout: A Simple Way to Prevent Neural Networks
//! from Overfitting" by Srivastava et al.

use crate::error::{NeuralError, Result};
use crate::layers::Layer;
use scirs2_core::ndarray::{Array, IxDyn, ScalarOperand};
use scirs2_core::numeric::Float;
use scirs2_core::random::{Rng, RngCore, SeedableRng};
use std::fmt::Debug;
use std::marker::PhantomData;
use std::sync::{Arc, RwLock};

/// Dropout layer
///
/// During training, randomly sets input elements to zero with probability `p` and scales
/// the surviving elements by 1/(1-p) (inverted dropout) to maintain the expected value.
/// During inference, the input is passed through unchanged.
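///
/// # Examples
///
/// A minimal usage sketch, assuming the crate re-exports `Dropout` and the `Layer` trait
/// as `scirs2_neural::layers::{Dropout, Layer}`; that path is an assumption, so the
/// doctest is marked `ignore`.
///
/// ```ignore
/// use scirs2_core::ndarray::{Array, IxDyn};
/// use scirs2_core::random::rngs::SmallRng;
/// use scirs2_core::random::SeedableRng;
/// use scirs2_neural::layers::{Dropout, Layer};
///
/// let mut rng = SmallRng::from_seed([42u8; 32]);
/// let mut dropout = Dropout::<f64>::new(0.5, &mut rng).unwrap();
///
/// // Training mode: roughly half the elements are zeroed, survivors are scaled by 2.0.
/// let input = Array::<f64, IxDyn>::from_elem(IxDyn(&[2, 4]), 1.0);
/// let output = dropout.forward(&input).unwrap();
/// assert_eq!(output.shape(), input.shape());
///
/// // Inference mode: the input passes through unchanged.
/// dropout.set_training(false);
/// assert_eq!(dropout.forward(&input).unwrap(), input);
/// ```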
pub struct Dropout<F: Float + Debug + Send + Sync> {
    /// Probability of dropping an element
    p: F,
    /// Random number generator
    rng: Arc<RwLock<Box<dyn RngCore + Send + Sync>>>,
    /// Whether we're in training mode
    training: bool,
    /// Input cache for backward pass
    input_cache: Arc<RwLock<Option<Array<F, IxDyn>>>>,
    /// Mask cache for backward pass (1 for kept elements, 0 for dropped)
    mask_cache: Arc<RwLock<Option<Array<F, IxDyn>>>>,
    /// Phantom data for type parameter
    _phantom: PhantomData<F>,
}

// Manual implementation of Debug because dyn RngCore doesn't implement Debug
impl<F: Float + Debug + Send + Sync> std::fmt::Debug for Dropout<F> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Dropout")
            .field("p", &self.p)
            .field("rng", &"<dyn RngCore>")
            .field("training", &self.training)
            .finish()
    }
}

// Manual implementation of Clone: the clone receives a freshly seeded RNG and empty caches
impl<F: Float + Debug + Send + Sync> Clone for Dropout<F> {
    fn clone(&self) -> Self {
        let rng = scirs2_core::random::rngs::SmallRng::from_seed([42; 32]);
        Self {
            p: self.p,
            rng: Arc::new(RwLock::new(Box::new(rng))),
            training: self.training,
            input_cache: Arc::new(RwLock::new(None)),
            mask_cache: Arc::new(RwLock::new(None)),
            _phantom: PhantomData,
        }
    }
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Dropout<F> {
    /// Create a new dropout layer
    ///
    /// # Arguments
    /// * `p` - Dropout probability in the range [0, 1)
    /// * `rng` - Random number generator
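    ///
    /// # Errors
    /// Returns `NeuralError::InvalidArchitecture` if `p` is outside [0, 1) or cannot be
    /// converted to `F`.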
    pub fn new<R: Rng + 'static + Clone + Send + Sync>(p: f64, rng: &mut R) -> Result<Self> {
        if !(0.0..1.0).contains(&p) {
            return Err(NeuralError::InvalidArchitecture(
                "Dropout probability must be in [0, 1)".to_string(),
            ));
        }

        let p = F::from(p).ok_or_else(|| {
            NeuralError::InvalidArchitecture(
                "Failed to convert dropout probability to type F".to_string(),
            )
        })?;

        Ok(Self {
            p,
            rng: Arc::new(RwLock::new(Box::new(rng.clone()))),
            training: true,
            input_cache: Arc::new(RwLock::new(None)),
            mask_cache: Arc::new(RwLock::new(None)),
            _phantom: PhantomData,
        })
    }

    /// Set the training mode
    ///
    /// In training mode, elements are randomly dropped and the survivors are scaled by 1/(1-p).
    /// In inference mode, all elements are kept unchanged.
    pub fn set_training(&mut self, training: bool) {
        self.training = training;
    }

    /// Get the dropout probability
    pub fn p(&self) -> f64 {
        self.p.to_f64().unwrap_or(0.0)
    }

    /// Get the training mode
    pub fn is_training(&self) -> bool {
        self.training
    }
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Layer<F> for Dropout<F> {
    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
        // Cache input for backward pass
        if let Ok(mut cache) = self.input_cache.write() {
            *cache = Some(input.clone());
        } else {
            return Err(NeuralError::InferenceError(
                "Failed to acquire write lock on input cache".to_string(),
            ));
        }

        if !self.training || self.p == F::zero() {
            // In inference mode or with p=0, just pass through the input as is
            return Ok(input.clone());
        }

        // In training mode, create a binary mask and apply it
        let mut mask = Array::<F, IxDyn>::from_elem(input.dim(), F::one());
        let one = F::one();
        let zero = F::zero();

        // Apply the dropout mask
        {
            let mut rng_guard = match self.rng.write() {
                Ok(guard) => guard,
                Err(_) => {
                    return Err(NeuralError::InferenceError(
                        "Failed to acquire write lock on RNG".to_string(),
                    ))
                }
            };

            for elem in mask.iter_mut() {
                if F::from((**rng_guard).random::<f64>()).unwrap() < self.p {
                    *elem = zero;
                }
            }
        }

        // Scale by 1/(1-p) to maintain expected value
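        // Each element survives with probability (1 - p), so E[mask * scale] = 1 and the
        // expected activation matches the unscaled inference-time output.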
        let scale = one / (one - self.p);

        // Cache the mask for backward pass
        if let Ok(mut cache) = self.mask_cache.write() {
            *cache = Some(mask.clone());
        } else {
            return Err(NeuralError::InferenceError(
                "Failed to acquire write lock on mask cache".to_string(),
            ));
        }

        // Apply mask and scale
        let output = input * &mask * scale;
        Ok(output)
    }

    fn backward(
        &self,
        _input: &Array<F, IxDyn>,
        grad_output: &Array<F, IxDyn>,
    ) -> Result<Array<F, IxDyn>> {
        if !self.training || self.p == F::zero() {
            // In inference mode, or with p=0 where forward is a pass-through and no mask
            // is cached, gradients pass through unchanged
            return Ok(grad_output.clone());
        }

        // Get cached mask
        let mask = {
            let cache = match self.mask_cache.read() {
                Ok(cache) => cache,
                Err(_) => {
                    return Err(NeuralError::InferenceError(
                        "Failed to acquire read lock on mask cache".to_string(),
                    ))
                }
            };

            match cache.as_ref() {
                Some(mask) => mask.clone(),
                None => {
                    return Err(NeuralError::InferenceError(
                        "No cached mask for backward pass".to_string(),
                    ))
                }
            }
        };

        // Scale factor
        let one = F::one();
        let scale = one / (one - self.p);

        // Apply mask and scale to gradients
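        // Since forward computed output = input * mask * scale, the local derivative with
        // respect to the input is mask * scale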
        let grad_input = grad_output * &mask * scale;
        Ok(grad_input)
    }

    fn update(&mut self, _learningrate: F) -> Result<()> {
        // Dropout has no parameters to update
        Ok(())
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }

    fn set_training(&mut self, training: bool) {
        self.training = training;
    }

    fn is_training(&self) -> bool {
        self.training
    }

    fn layer_type(&self) -> &str {
        "Dropout"
    }

    fn parameter_count(&self) -> usize {
        0 // Dropout has no trainable parameters
    }

    fn layer_description(&self) -> String {
        format!("type:Dropout, p:{:.3}", self.p())
    }
}

// Explicit Send + Sync implementations for the Dropout layer. These are sound because every
// field is itself Send + Sync: the boxed RNG is constrained to Send + Sync and the cached
// arrays sit behind Arc<RwLock<...>>.
unsafe impl<F: Float + Debug + Send + Sync> Send for Dropout<F> {}
unsafe impl<F: Float + Debug + Send + Sync> Sync for Dropout<F> {}
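// A small test sketch exercising the behaviour documented above. It only relies on items
// already used in this file (SmallRng::from_seed, Array::from_elem, the Layer trait); the
// exact assertions are illustrative rather than part of the original module.
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::random::rngs::SmallRng;

    fn input() -> Array<f64, IxDyn> {
        Array::from_elem(IxDyn(&[4, 8]), 1.0)
    }

    #[test]
    fn inference_is_identity() {
        let mut rng = SmallRng::from_seed([7; 32]);
        let mut dropout = Dropout::<f64>::new(0.5, &mut rng).unwrap();
        dropout.set_training(false);
        let x = input();
        assert_eq!(dropout.forward(&x).unwrap(), x);
    }

    #[test]
    fn training_zeroes_or_scales() {
        let mut rng = SmallRng::from_seed([7; 32]);
        let dropout = Dropout::<f64>::new(0.5, &mut rng).unwrap();
        let y = dropout.forward(&input()).unwrap();
        // With inverted dropout every element is either dropped (0.0) or scaled by 1/(1-p) = 2.0
        assert!(y.iter().all(|&v| v == 0.0 || (v - 2.0).abs() < 1e-12));
    }

    #[test]
    fn invalid_probability_is_rejected() {
        let mut rng = SmallRng::from_seed([7; 32]);
        assert!(Dropout::<f64>::new(1.0, &mut rng).is_err());
    }
}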