Skip to main content

tensorlogic_train/
dropblock.rs

1//! DropBlock regularization for convolutional networks.
2//!
3//! Implements DropBlock, a structured form of dropout that drops contiguous regions
4//! rather than independent random units. This is particularly effective for convolutional
5//! neural networks where spatial correlation means standard dropout is less effective.
6//!
7//! # References
8//!
9//! - Ghiasi, G., Lin, T. Y., & Le, Q. V. (2018).
10//!   "DropBlock: A regularization method for convolutional networks". NeurIPS 2018.
11//!   <https://arxiv.org/abs/1810.12890>
12//!
13//! - Used in:
14//!   - ResNets (ImageNet)
15//!   - AmoebaNet
16//!   - EfficientNet variants
17//!   - Modern CNNs in general
18//!
19//! # Key Concepts
20//!
21//! **DropBlock vs Dropout**:
22//! - Dropout: Randomly zeros individual units/pixels
23//! - DropBlock: Randomly zeros contiguous blocks/regions
24//!
25//! **Why DropBlock works better for CNNs**:
26//! - Convolutional layers have spatial correlation
27//! - Dropping individual pixels allows network to use nearby activations
28//! - Dropping blocks forces network to learn more robust features
29//!
30//! **Block size**: Typically 7x7 or 5x5 for images
31//! **Drop probability**: Should be scheduled (linear increase during training)
32//!
33//! # Example
34//!
35//! ```rust
36//! use tensorlogic_train::DropBlock;
37//! use scirs2_core::ndarray::Array2;
38//! use scirs2_core::random::{StdRng, SeedableRng};
39//!
40//! // Create DropBlock with block_size=3, drop_prob=0.1
41//! let dropblock = DropBlock::new(3, 0.1).expect("unwrap");
42//!
43//! let mut rng = StdRng::seed_from_u64(42);
44//! let activations = Array2::ones((10, 10));
45//!
46//! // Training: drop blocks
47//! let dropped = dropblock.apply(&activations.view(), true, &mut rng).expect("unwrap");
48//!
49//! // Inference: no dropping
50//! let output = dropblock.apply(&activations.view(), false, &mut rng).expect("unwrap");
51//! assert_eq!(output, activations);
52//! ```
53
54use crate::{TrainError, TrainResult};
55use scirs2_core::ndarray::{Array2, ArrayView2};
56use scirs2_core::random::{RngExt, StdRng};
57
/// DropBlock regularization.
///
/// Drops contiguous square blocks of activations in a 2-D feature map instead of
/// independent units. Neighbouring activations in convolutional feature maps are
/// correlated, so removing whole regions forces the network to rely on more
/// distributed evidence than ordinary dropout does.
///
/// Construct with [`DropBlock::new`], which validates the parameters.
#[derive(Debug, Clone)]
pub struct DropBlock {
    /// Side length of each square block to drop (odd and >= 1, e.g. 7 for 7x7 blocks).
    pub block_size: usize,

    /// Target fraction of activations to drop during training, in `[0.0, 1.0]`.
    ///
    /// This is not the per-position seed probability: `apply` derives that
    /// probability (gamma) from this value, the block size and the feature-map size.
    /// Prefer [`DropBlock::set_drop_prob`] over assigning this field directly so the
    /// value stays validated and consistent with `keep_prob`.
    pub drop_prob: f64,

    /// Keep probability, `1.0 - drop_prob`; kept in sync by `new` and `set_drop_prob`.
    keep_prob: f64,
}
74
75impl DropBlock {
76    /// Create a new DropBlock regularizer.
77    ///
78    /// # Arguments
79    /// * `block_size` - Size of the block to drop (must be odd and >= 1)
80    /// * `drop_prob` - Probability of dropping a block (0.0 to 1.0)
81    ///
82    /// # Returns
83    /// A new DropBlock instance or an error if parameters are invalid.
84    ///
85    /// # Example
86    /// ```rust
87    /// use tensorlogic_train::DropBlock;
88    ///
89    /// let dropblock = DropBlock::new(7, 0.1).expect("unwrap");
90    /// ```
91    pub fn new(block_size: usize, drop_prob: f64) -> TrainResult<Self> {
92        if block_size == 0 {
93            return Err(TrainError::InvalidParameter(
94                "block_size must be at least 1".to_string(),
95            ));
96        }
97
98        if block_size.is_multiple_of(2) {
99            return Err(TrainError::InvalidParameter(
100                "block_size must be odd".to_string(),
101            ));
102        }
103
104        if !(0.0..=1.0).contains(&drop_prob) {
105            return Err(TrainError::InvalidParameter(
106                "drop_prob must be between 0.0 and 1.0".to_string(),
107            ));
108        }
109
110        Ok(Self {
111            block_size,
112            drop_prob,
113            keep_prob: 1.0 - drop_prob,
114        })
115    }
116
117    /// Set the drop probability (useful for scheduling).
118    ///
119    /// # Arguments
120    /// * `drop_prob` - New drop probability (0.0 to 1.0)
121    pub fn set_drop_prob(&mut self, drop_prob: f64) -> TrainResult<()> {
122        if !(0.0..=1.0).contains(&drop_prob) {
123            return Err(TrainError::InvalidParameter(
124                "drop_prob must be between 0.0 and 1.0".to_string(),
125            ));
126        }
127
128        self.drop_prob = drop_prob;
129        self.keep_prob = 1.0 - drop_prob;
130        Ok(())
131    }
132
133    /// Apply DropBlock to activations.
134    ///
135    /// # Arguments
136    /// * `activations` - Input activation map (height × width)
137    /// * `training` - Whether in training mode (drops blocks) or inference mode (no dropping)
138    /// * `rng` - Random number generator
139    ///
140    /// # Returns
141    /// Activation map with blocks dropped (if training) or unchanged (if inference)
142    ///
143    /// # Algorithm
144    /// 1. Sample Bernoulli mask for each position (potential block centers)
145    /// 2. Expand each selected center to a block of size block_size × block_size
146    /// 3. Zero out the selected blocks
147    /// 4. Normalize by keep_prob to maintain expected value
148    pub fn apply(
149        &self,
150        activations: &ArrayView2<f64>,
151        training: bool,
152        rng: &mut StdRng,
153    ) -> TrainResult<Array2<f64>> {
154        if !training || self.drop_prob == 0.0 {
155            return Ok(activations.to_owned());
156        }
157
158        let (height, width) = activations.dim();
159
160        if height < self.block_size || width < self.block_size {
161            return Err(TrainError::InvalidParameter(format!(
162                "Activation map size ({}x{}) is smaller than block_size ({})",
163                height, width, self.block_size
164            )));
165        }
166
167        // Compute gamma (adjusted drop probability accounting for block size)
168        // This ensures the expected fraction of dropped units matches drop_prob
169        let gamma = self.drop_prob * (height * width) as f64
170            / ((height - self.block_size + 1) * (width - self.block_size + 1)) as f64
171            / (self.block_size * self.block_size) as f64;
172
173        // Sample block centers using Bernoulli(gamma)
174        let mut mask = Array2::ones((height, width));
175        let half_block = self.block_size / 2;
176
177        for i in 0..height {
178            for j in 0..width {
179                if rng.random::<f64>() < gamma {
180                    // This position is a block center - zero out the block
181                    let i_start = i.saturating_sub(half_block);
182                    let i_end = (i + half_block + 1).min(height);
183                    let j_start = j.saturating_sub(half_block);
184                    let j_end = (j + half_block + 1).min(width);
185
186                    for ii in i_start..i_end {
187                        for jj in j_start..j_end {
188                            mask[[ii, jj]] = 0.0;
189                        }
190                    }
191                }
192            }
193        }
194
195        // Apply mask and normalize
196        let mut output = activations.to_owned();
197        let count_kept = mask.iter().filter(|&&x| x == 1.0).count();
198        let normalization_factor = if count_kept > 0 {
199            (height * width) as f64 / count_kept as f64
200        } else {
201            1.0
202        };
203
204        for i in 0..height {
205            for j in 0..width {
206                output[[i, j]] *= mask[[i, j]] * normalization_factor;
207            }
208        }
209
210        Ok(output)
211    }
212}
213
/// Linear warm-up schedule for the DropBlock drop probability.
///
/// The schedule starts at `0.0` on step 0 and grows linearly until it reaches
/// `drop_prob_target` at `total_steps`. It stays at the target from then on.
/// The DropBlock paper recommends ramping the drop rate up gradually rather than
/// applying the final value from the first step.
#[derive(Debug, Clone)]
pub struct LinearDropBlockScheduler {
    /// Drop probability reached at the end of the ramp (the schedule's maximum).
    pub drop_prob_target: f64,

    /// Length of the ramp in training steps.
    pub total_steps: usize,
}
226
227impl LinearDropBlockScheduler {
228    /// Create a new linear scheduler.
229    ///
230    /// # Arguments
231    /// * `drop_prob_target` - Final drop probability to reach
232    /// * `total_steps` - Number of training steps to linearly increase over
233    pub fn new(drop_prob_target: f64, total_steps: usize) -> TrainResult<Self> {
234        if !(0.0..=1.0).contains(&drop_prob_target) {
235            return Err(TrainError::InvalidParameter(
236                "drop_prob_target must be between 0.0 and 1.0".to_string(),
237            ));
238        }
239
240        if total_steps == 0 {
241            return Err(TrainError::InvalidParameter(
242                "total_steps must be at least 1".to_string(),
243            ));
244        }
245
246        Ok(Self {
247            drop_prob_target,
248            total_steps,
249        })
250    }
251
252    /// Get drop probability for current step.
253    ///
254    /// # Arguments
255    /// * `current_step` - Current training step (0-indexed)
256    ///
257    /// # Returns
258    /// Drop probability linearly interpolated from 0 to target
259    pub fn get_drop_prob(&self, current_step: usize) -> f64 {
260        if current_step >= self.total_steps {
261            return self.drop_prob_target;
262        }
263
264        let progress = current_step as f64 / self.total_steps as f64;
265        self.drop_prob_target * progress
266    }
267}
268
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::random::SeedableRng;

    /// Fresh RNG with a fixed seed so every test is reproducible.
    fn seeded_rng() -> StdRng {
        StdRng::seed_from_u64(42)
    }

    /// Square `side × side` activation map filled with ones.
    fn ones(side: usize) -> Array2<f64> {
        Array2::ones((side, side))
    }

    #[test]
    fn test_dropblock_creation() {
        let db = DropBlock::new(7, 0.1).expect("unwrap");
        assert_eq!((db.block_size, db.drop_prob, db.keep_prob), (7, 0.1, 0.9));
    }

    #[test]
    fn test_dropblock_invalid_params() {
        // Zero block, even block, negative probability, probability above one.
        let rejected = [(0, 0.1), (4, 0.1), (7, -0.1), (7, 1.5)];
        for (block_size, drop_prob) in rejected {
            assert!(
                DropBlock::new(block_size, drop_prob).is_err(),
                "DropBlock::new({block_size}, {drop_prob}) should be rejected"
            );
        }
    }

    #[test]
    fn test_dropblock_set_drop_prob() {
        let mut db = DropBlock::new(7, 0.1).expect("unwrap");

        db.set_drop_prob(0.2).expect("unwrap");
        assert_eq!((db.drop_prob, db.keep_prob), (0.2, 0.8));

        assert!(
            db.set_drop_prob(1.5).is_err(),
            "probabilities above 1.0 must be rejected"
        );
    }

    #[test]
    fn test_dropblock_inference_mode() {
        let db = DropBlock::new(3, 0.5).expect("unwrap");
        let input = ones(10);

        let output = db
            .apply(&input.view(), false, &mut seeded_rng())
            .expect("unwrap");

        // Inference is the identity.
        assert_eq!(output, input);
    }

    #[test]
    fn test_dropblock_zero_prob() {
        let db = DropBlock::new(3, 0.0).expect("unwrap");
        let input = ones(10);

        let output = db
            .apply(&input.view(), true, &mut seeded_rng())
            .expect("unwrap");

        // A zero drop probability never drops anything, even while training.
        assert_eq!(output, input);
    }

    #[test]
    fn test_dropblock_training_mode() {
        let db = DropBlock::new(3, 0.3).expect("unwrap");
        let input = ones(20);

        let output = db
            .apply(&input.view(), true, &mut seeded_rng())
            .expect("unwrap");

        assert_eq!(output.shape(), input.shape());

        let total = input.len();
        let dropped = output.iter().filter(|&&v| v == 0.0).count();
        assert!(dropped > 0, "Expected some blocks to be dropped");
        assert!(dropped < total, "Not all values should be dropped");
    }

    #[test]
    fn test_dropblock_small_activation_map() {
        let db = DropBlock::new(7, 0.1).expect("unwrap");
        let input = ones(5);

        // A 5x5 map cannot hold a 7x7 block.
        assert!(db.apply(&input.view(), true, &mut seeded_rng()).is_err());
    }

    #[test]
    fn test_linear_scheduler_creation() {
        let scheduler = LinearDropBlockScheduler::new(0.1, 1000).expect("unwrap");
        assert_eq!(
            (scheduler.drop_prob_target, scheduler.total_steps),
            (0.1, 1000)
        );
    }

    #[test]
    fn test_linear_scheduler_invalid_params() {
        // Out-of-range targets, then a zero-length ramp.
        let rejected = [(-0.1, 1000), (1.5, 1000), (0.1, 0)];
        for (target, steps) in rejected {
            assert!(
                LinearDropBlockScheduler::new(target, steps).is_err(),
                "LinearDropBlockScheduler::new({target}, {steps}) should be rejected"
            );
        }
    }

    #[test]
    fn test_linear_scheduler_interpolation() {
        let scheduler = LinearDropBlockScheduler::new(0.1, 100).expect("unwrap");

        assert_eq!(scheduler.get_drop_prob(0), 0.0);
        assert!((scheduler.get_drop_prob(50) - 0.05).abs() < 1e-10);

        // At the end of the ramp and beyond it, the target is returned exactly.
        for step in [100, 150] {
            assert_eq!(scheduler.get_drop_prob(step), 0.1);
        }
    }

    #[test]
    fn test_dropblock_with_scheduler() {
        let mut db = DropBlock::new(3, 0.0).expect("unwrap");
        let scheduler = LinearDropBlockScheduler::new(0.2, 100).expect("unwrap");
        let mut rng = seeded_rng();
        let input = ones(20);

        // One RNG is shared across the simulated training steps.
        for step in [0, 50, 100] {
            db.set_drop_prob(scheduler.get_drop_prob(step))
                .expect("unwrap");

            let output = db
                .apply(&input.view(), true, &mut rng)
                .expect("unwrap");
            assert_eq!(output.shape(), input.shape());
        }
    }

    #[test]
    fn test_dropblock_normalization() {
        let db = DropBlock::new(3, 0.1).expect("unwrap");
        let input = ones(20);

        let output = db
            .apply(&input.view(), true, &mut seeded_rng())
            .expect("unwrap");

        // Rescaling the survivors should keep the total activation mass close to the input's.
        let expected = input.sum();
        let relative_diff = (output.sum() - expected).abs() / expected;
        assert!(
            relative_diff < 0.5,
            "Normalization should preserve approximate expected value"
        );
    }
}