1use crate::{TrainError, TrainResult};
55use scirs2_core::ndarray::{Array2, ArrayView2};
56use scirs2_core::random::{RngExt, StdRng};
57
58#[derive(Debug, Clone)]
64pub struct DropBlock {
65 pub block_size: usize,
67
68 pub drop_prob: f64,
70
71 keep_prob: f64,
73}
74
75impl DropBlock {
76 pub fn new(block_size: usize, drop_prob: f64) -> TrainResult<Self> {
92 if block_size == 0 {
93 return Err(TrainError::InvalidParameter(
94 "block_size must be at least 1".to_string(),
95 ));
96 }
97
98 if block_size.is_multiple_of(2) {
99 return Err(TrainError::InvalidParameter(
100 "block_size must be odd".to_string(),
101 ));
102 }
103
104 if !(0.0..=1.0).contains(&drop_prob) {
105 return Err(TrainError::InvalidParameter(
106 "drop_prob must be between 0.0 and 1.0".to_string(),
107 ));
108 }
109
110 Ok(Self {
111 block_size,
112 drop_prob,
113 keep_prob: 1.0 - drop_prob,
114 })
115 }
116
117 pub fn set_drop_prob(&mut self, drop_prob: f64) -> TrainResult<()> {
122 if !(0.0..=1.0).contains(&drop_prob) {
123 return Err(TrainError::InvalidParameter(
124 "drop_prob must be between 0.0 and 1.0".to_string(),
125 ));
126 }
127
128 self.drop_prob = drop_prob;
129 self.keep_prob = 1.0 - drop_prob;
130 Ok(())
131 }
132
133 pub fn apply(
149 &self,
150 activations: &ArrayView2<f64>,
151 training: bool,
152 rng: &mut StdRng,
153 ) -> TrainResult<Array2<f64>> {
154 if !training || self.drop_prob == 0.0 {
155 return Ok(activations.to_owned());
156 }
157
158 let (height, width) = activations.dim();
159
160 if height < self.block_size || width < self.block_size {
161 return Err(TrainError::InvalidParameter(format!(
162 "Activation map size ({}x{}) is smaller than block_size ({})",
163 height, width, self.block_size
164 )));
165 }
166
167 let gamma = self.drop_prob * (height * width) as f64
170 / ((height - self.block_size + 1) * (width - self.block_size + 1)) as f64
171 / (self.block_size * self.block_size) as f64;
172
173 let mut mask = Array2::ones((height, width));
175 let half_block = self.block_size / 2;
176
177 for i in 0..height {
178 for j in 0..width {
179 if rng.random::<f64>() < gamma {
180 let i_start = i.saturating_sub(half_block);
182 let i_end = (i + half_block + 1).min(height);
183 let j_start = j.saturating_sub(half_block);
184 let j_end = (j + half_block + 1).min(width);
185
186 for ii in i_start..i_end {
187 for jj in j_start..j_end {
188 mask[[ii, jj]] = 0.0;
189 }
190 }
191 }
192 }
193 }
194
195 let mut output = activations.to_owned();
197 let count_kept = mask.iter().filter(|&&x| x == 1.0).count();
198 let normalization_factor = if count_kept > 0 {
199 (height * width) as f64 / count_kept as f64
200 } else {
201 1.0
202 };
203
204 for i in 0..height {
205 for j in 0..width {
206 output[[i, j]] *= mask[[i, j]] * normalization_factor;
207 }
208 }
209
210 Ok(output)
211 }
212}
213
214#[derive(Debug, Clone)]
219pub struct LinearDropBlockScheduler {
220 pub drop_prob_target: f64,
222
223 pub total_steps: usize,
225}
226
227impl LinearDropBlockScheduler {
228 pub fn new(drop_prob_target: f64, total_steps: usize) -> TrainResult<Self> {
234 if !(0.0..=1.0).contains(&drop_prob_target) {
235 return Err(TrainError::InvalidParameter(
236 "drop_prob_target must be between 0.0 and 1.0".to_string(),
237 ));
238 }
239
240 if total_steps == 0 {
241 return Err(TrainError::InvalidParameter(
242 "total_steps must be at least 1".to_string(),
243 ));
244 }
245
246 Ok(Self {
247 drop_prob_target,
248 total_steps,
249 })
250 }
251
252 pub fn get_drop_prob(&self, current_step: usize) -> f64 {
260 if current_step >= self.total_steps {
261 return self.drop_prob_target;
262 }
263
264 let progress = current_step as f64 / self.total_steps as f64;
265 self.drop_prob_target * progress
266 }
267}
268
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::random::SeedableRng;

    /// Seed shared by every test that draws random numbers, so runs are reproducible.
    const SEED: u64 = 42;

    /// Builds a `rows x cols` activation map filled with `1.0`.
    fn ones_map(rows: usize, cols: usize) -> Array2<f64> {
        Array2::from_elem((rows, cols), 1.0)
    }

    #[test]
    fn test_dropblock_creation() {
        let layer = DropBlock::new(7, 0.1).expect("unwrap");
        assert_eq!(layer.block_size, 7);
        assert_eq!(layer.drop_prob, 0.1);
        assert_eq!(layer.keep_prob, 0.9);
    }

    #[test]
    fn test_dropblock_invalid_params() {
        // zero size, even size, probability below 0.0, probability above 1.0
        let bad_configs = [(0, 0.1), (4, 0.1), (7, -0.1), (7, 1.5)];
        for (block_size, drop_prob) in bad_configs {
            assert!(
                DropBlock::new(block_size, drop_prob).is_err(),
                "expected DropBlock::new({block_size}, {drop_prob}) to fail"
            );
        }
    }

    #[test]
    fn test_dropblock_set_drop_prob() {
        let mut layer = DropBlock::new(7, 0.1).expect("unwrap");

        layer.set_drop_prob(0.2).expect("unwrap");
        assert_eq!((layer.drop_prob, layer.keep_prob), (0.2, 0.8));

        assert!(layer.set_drop_prob(1.5).is_err());
    }

    #[test]
    fn test_dropblock_inference_mode() {
        let layer = DropBlock::new(3, 0.5).expect("unwrap");
        let mut rng = StdRng::seed_from_u64(SEED);
        let input = ones_map(10, 10);

        let output = layer
            .apply(&input.view(), false, &mut rng)
            .expect("unwrap");
        assert_eq!(output, input);
    }

    #[test]
    fn test_dropblock_zero_prob() {
        let layer = DropBlock::new(3, 0.0).expect("unwrap");
        let mut rng = StdRng::seed_from_u64(SEED);
        let input = ones_map(10, 10);

        let output = layer
            .apply(&input.view(), true, &mut rng)
            .expect("unwrap");
        assert_eq!(output, input);
    }

    #[test]
    fn test_dropblock_training_mode() {
        let layer = DropBlock::new(3, 0.3).expect("unwrap");
        let mut rng = StdRng::seed_from_u64(SEED);
        let input = ones_map(20, 20);

        let output = layer
            .apply(&input.view(), true, &mut rng)
            .expect("unwrap");
        assert_eq!(output.shape(), input.shape());

        let dropped = output.iter().filter(|v| **v == 0.0).count();
        assert!(dropped > 0, "Expected some blocks to be dropped");
        assert!(dropped < input.len(), "Not all values should be dropped");
    }

    #[test]
    fn test_dropblock_small_activation_map() {
        // A 5x5 map cannot hold a 7x7 block.
        let layer = DropBlock::new(7, 0.1).expect("unwrap");
        let mut rng = StdRng::seed_from_u64(SEED);
        let input = ones_map(5, 5);

        assert!(layer.apply(&input.view(), true, &mut rng).is_err());
    }

    #[test]
    fn test_linear_scheduler_creation() {
        let schedule = LinearDropBlockScheduler::new(0.1, 1000).expect("unwrap");
        assert_eq!(schedule.drop_prob_target, 0.1);
        assert_eq!(schedule.total_steps, 1000);
    }

    #[test]
    fn test_linear_scheduler_invalid_params() {
        // negative target, target above 1.0, empty ramp
        for (target, steps) in [(-0.1, 1000), (1.5, 1000), (0.1, 0)] {
            assert!(LinearDropBlockScheduler::new(target, steps).is_err());
        }
    }

    #[test]
    fn test_linear_scheduler_interpolation() {
        let schedule = LinearDropBlockScheduler::new(0.1, 100).expect("unwrap");

        assert_eq!(schedule.get_drop_prob(0), 0.0);
        assert!((schedule.get_drop_prob(50) - 0.05).abs() < 1e-10);

        // At and past `total_steps` the target value is held.
        for step in [100, 150] {
            assert_eq!(schedule.get_drop_prob(step), 0.1);
        }
    }

    #[test]
    fn test_dropblock_with_scheduler() {
        let mut layer = DropBlock::new(3, 0.0).expect("unwrap");
        let schedule = LinearDropBlockScheduler::new(0.2, 100).expect("unwrap");
        let mut rng = StdRng::seed_from_u64(SEED);
        let input = ones_map(20, 20);

        for step in [0, 50, 100] {
            layer
                .set_drop_prob(schedule.get_drop_prob(step))
                .expect("unwrap");

            let output = layer
                .apply(&input.view(), true, &mut rng)
                .expect("unwrap");
            assert_eq!(output.shape(), input.shape());
        }
    }

    #[test]
    fn test_dropblock_normalization() {
        let layer = DropBlock::new(3, 0.1).expect("unwrap");
        let mut rng = StdRng::seed_from_u64(SEED);
        let input = ones_map(20, 20);

        let output = layer
            .apply(&input.view(), true, &mut rng)
            .expect("unwrap");

        let expected_sum = input.sum();
        let relative_diff = (output.sum() - expected_sum).abs() / expected_sum;
        assert!(
            relative_diff < 0.5,
            "Normalization should preserve approximate expected value"
        );
    }
}