1use crate::{TrainError, TrainResult};
55use scirs2_core::ndarray::{Array2, ArrayView2};
56use scirs2_core::random::{Rng, StdRng};
57
/// DropBlock regularization for 2-D activation maps: during training,
/// contiguous square blocks of activations are zeroed out and the
/// survivors are rescaled (Ghiasi et al., 2018).
#[derive(Debug, Clone)]
pub struct DropBlock {
    /// Side length of the square block to drop; must be odd and >= 1.
    pub block_size: usize,

    /// Target fraction of activations to drop, in `[0.0, 1.0]`.
    pub drop_prob: f64,

    // Cached complement of `drop_prob` (kept in sync by the setters).
    keep_prob: f64,
}
74
75impl DropBlock {
76 pub fn new(block_size: usize, drop_prob: f64) -> TrainResult<Self> {
92 if block_size == 0 {
93 return Err(TrainError::InvalidParameter(
94 "block_size must be at least 1".to_string(),
95 ));
96 }
97
98 if block_size.is_multiple_of(2) {
99 return Err(TrainError::InvalidParameter(
100 "block_size must be odd".to_string(),
101 ));
102 }
103
104 if !(0.0..=1.0).contains(&drop_prob) {
105 return Err(TrainError::InvalidParameter(
106 "drop_prob must be between 0.0 and 1.0".to_string(),
107 ));
108 }
109
110 Ok(Self {
111 block_size,
112 drop_prob,
113 keep_prob: 1.0 - drop_prob,
114 })
115 }
116
117 pub fn set_drop_prob(&mut self, drop_prob: f64) -> TrainResult<()> {
122 if !(0.0..=1.0).contains(&drop_prob) {
123 return Err(TrainError::InvalidParameter(
124 "drop_prob must be between 0.0 and 1.0".to_string(),
125 ));
126 }
127
128 self.drop_prob = drop_prob;
129 self.keep_prob = 1.0 - drop_prob;
130 Ok(())
131 }
132
133 pub fn apply(
149 &self,
150 activations: &ArrayView2<f64>,
151 training: bool,
152 rng: &mut StdRng,
153 ) -> TrainResult<Array2<f64>> {
154 if !training || self.drop_prob == 0.0 {
155 return Ok(activations.to_owned());
156 }
157
158 let (height, width) = activations.dim();
159
160 if height < self.block_size || width < self.block_size {
161 return Err(TrainError::InvalidParameter(format!(
162 "Activation map size ({}x{}) is smaller than block_size ({})",
163 height, width, self.block_size
164 )));
165 }
166
167 let gamma = self.drop_prob * (height * width) as f64
170 / ((height - self.block_size + 1) * (width - self.block_size + 1)) as f64
171 / (self.block_size * self.block_size) as f64;
172
173 let mut mask = Array2::ones((height, width));
175 let half_block = self.block_size / 2;
176
177 for i in 0..height {
178 for j in 0..width {
179 if rng.random::<f64>() < gamma {
180 let i_start = i.saturating_sub(half_block);
182 let i_end = (i + half_block + 1).min(height);
183 let j_start = j.saturating_sub(half_block);
184 let j_end = (j + half_block + 1).min(width);
185
186 for ii in i_start..i_end {
187 for jj in j_start..j_end {
188 mask[[ii, jj]] = 0.0;
189 }
190 }
191 }
192 }
193 }
194
195 let mut output = activations.to_owned();
197 let count_kept = mask.iter().filter(|&&x| x == 1.0).count();
198 let normalization_factor = if count_kept > 0 {
199 (height * width) as f64 / count_kept as f64
200 } else {
201 1.0
202 };
203
204 for i in 0..height {
205 for j in 0..width {
206 output[[i, j]] *= mask[[i, j]] * normalization_factor;
207 }
208 }
209
210 Ok(output)
211 }
212}
213
/// Linearly ramps the DropBlock drop probability from 0.0 up to a target
/// value over a fixed number of training steps ("scheduled DropBlock").
#[derive(Debug, Clone)]
pub struct LinearDropBlockScheduler {
    /// Final drop probability reached at (and after) `total_steps`.
    pub drop_prob_target: f64,

    /// Number of steps over which the probability ramps up; must be >= 1.
    pub total_steps: usize,
}
226
227impl LinearDropBlockScheduler {
228 pub fn new(drop_prob_target: f64, total_steps: usize) -> TrainResult<Self> {
234 if !(0.0..=1.0).contains(&drop_prob_target) {
235 return Err(TrainError::InvalidParameter(
236 "drop_prob_target must be between 0.0 and 1.0".to_string(),
237 ));
238 }
239
240 if total_steps == 0 {
241 return Err(TrainError::InvalidParameter(
242 "total_steps must be at least 1".to_string(),
243 ));
244 }
245
246 Ok(Self {
247 drop_prob_target,
248 total_steps,
249 })
250 }
251
252 pub fn get_drop_prob(&self, current_step: usize) -> f64 {
260 if current_step >= self.total_steps {
261 return self.drop_prob_target;
262 }
263
264 let progress = current_step as f64 / self.total_steps as f64;
265 self.drop_prob_target * progress
266 }
267}
268
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::random::SeedableRng;

    // Constructor stores both probabilities and the derived keep_prob.
    #[test]
    fn test_dropblock_creation() {
        let db = DropBlock::new(7, 0.1).unwrap();
        assert_eq!(db.block_size, 7);
        assert_eq!(db.drop_prob, 0.1);
        assert_eq!(db.keep_prob, 0.9);
    }

    #[test]
    fn test_dropblock_invalid_params() {
        // Zero block size is rejected.
        assert!(DropBlock::new(0, 0.1).is_err());

        // Even block sizes are rejected (no center cell).
        assert!(DropBlock::new(4, 0.1).is_err());

        // drop_prob outside [0.0, 1.0] is rejected on both sides.
        assert!(DropBlock::new(7, -0.1).is_err());
        assert!(DropBlock::new(7, 1.5).is_err());
    }

    #[test]
    fn test_dropblock_set_drop_prob() {
        let mut db = DropBlock::new(7, 0.1).unwrap();

        // A valid update changes drop_prob and refreshes keep_prob.
        db.set_drop_prob(0.2).unwrap();
        assert_eq!(db.drop_prob, 0.2);
        assert_eq!(db.keep_prob, 0.8);

        // Out-of-range updates are rejected.
        assert!(db.set_drop_prob(1.5).is_err());
    }

    // With training == false, apply must be the identity.
    #[test]
    fn test_dropblock_inference_mode() {
        let db = DropBlock::new(3, 0.5).unwrap();
        let mut rng = StdRng::seed_from_u64(42);

        let activations = Array2::ones((10, 10));
        let output = db.apply(&activations.view(), false, &mut rng).unwrap();

        assert_eq!(output, activations);
    }

    // With drop_prob == 0.0, apply must be the identity even in training.
    #[test]
    fn test_dropblock_zero_prob() {
        let db = DropBlock::new(3, 0.0).unwrap();
        let mut rng = StdRng::seed_from_u64(42);

        let activations = Array2::ones((10, 10));
        let output = db.apply(&activations.view(), true, &mut rng).unwrap();

        assert_eq!(output, activations);
    }

    #[test]
    fn test_dropblock_training_mode() {
        let db = DropBlock::new(3, 0.3).unwrap();
        let mut rng = StdRng::seed_from_u64(42);

        let activations = Array2::ones((20, 20));
        let output = db.apply(&activations.view(), true, &mut rng).unwrap();

        // Shape is preserved.
        assert_eq!(output.shape(), activations.shape());

        // NOTE(review): these counts rely on the exact draw sequence of
        // the seeded RNG; with seed 42 and drop_prob 0.3 some, but not
        // all, values are expected to be zeroed.
        let zeros_count = output.iter().filter(|&&x| x == 0.0).count();
        assert!(zeros_count > 0, "Expected some blocks to be dropped");

        assert!(zeros_count < 400, "Not all values should be dropped");
    }

    // Activation maps smaller than block_size must be rejected.
    #[test]
    fn test_dropblock_small_activation_map() {
        let db = DropBlock::new(7, 0.1).unwrap();
        let mut rng = StdRng::seed_from_u64(42);

        let activations = Array2::ones((5, 5));
        let result = db.apply(&activations.view(), true, &mut rng);

        assert!(result.is_err());
    }

    #[test]
    fn test_linear_scheduler_creation() {
        let scheduler = LinearDropBlockScheduler::new(0.1, 1000).unwrap();
        assert_eq!(scheduler.drop_prob_target, 0.1);
        assert_eq!(scheduler.total_steps, 1000);
    }

    #[test]
    fn test_linear_scheduler_invalid_params() {
        // Target probability outside [0.0, 1.0] is rejected.
        assert!(LinearDropBlockScheduler::new(-0.1, 1000).is_err());
        assert!(LinearDropBlockScheduler::new(1.5, 1000).is_err());

        // Zero total_steps is rejected (would divide by zero later).
        assert!(LinearDropBlockScheduler::new(0.1, 0).is_err());
    }

    #[test]
    fn test_linear_scheduler_interpolation() {
        let scheduler = LinearDropBlockScheduler::new(0.1, 100).unwrap();

        // Ramp starts at exactly zero.
        assert_eq!(scheduler.get_drop_prob(0), 0.0);

        // Halfway through the ramp, half the target is applied.
        let mid_prob = scheduler.get_drop_prob(50);
        assert!((mid_prob - 0.05).abs() < 1e-10);

        // At and past total_steps, the target is returned exactly.
        assert_eq!(scheduler.get_drop_prob(100), 0.1);

        assert_eq!(scheduler.get_drop_prob(150), 0.1);
    }

    // Scheduler-driven updates keep apply usable at every ramp stage.
    #[test]
    fn test_dropblock_with_scheduler() {
        let mut db = DropBlock::new(3, 0.0).unwrap();
        let scheduler = LinearDropBlockScheduler::new(0.2, 100).unwrap();
        let mut rng = StdRng::seed_from_u64(42);

        let activations = Array2::ones((20, 20));

        for step in [0, 50, 100] {
            let drop_prob = scheduler.get_drop_prob(step);
            db.set_drop_prob(drop_prob).unwrap();

            let output = db.apply(&activations.view(), true, &mut rng).unwrap();
            assert_eq!(output.shape(), activations.shape());
        }
    }

    // Renormalization should keep the output sum close to the input sum.
    #[test]
    fn test_dropblock_normalization() {
        let db = DropBlock::new(3, 0.1).unwrap();
        let mut rng = StdRng::seed_from_u64(42);

        let activations = Array2::from_elem((20, 20), 1.0);
        let output = db.apply(&activations.view(), true, &mut rng).unwrap();

        let input_sum = activations.sum();
        let output_sum = output.sum();

        // Loose bound: the rescaling is only approximate per realization.
        let relative_diff = (output_sum - input_sum).abs() / input_sum;
        assert!(
            relative_diff < 0.5,
            "Normalization should preserve approximate expected value"
        );
    }
}