1use crate::{TrainError, TrainResult};
35use scirs2_core::ndarray::{Array2, ArrayView2};
36use scirs2_core::random::{Rng, StdRng};
37
/// Drop-path (stochastic depth) regularizer for a single residual path.
///
/// During training the whole path is randomly zeroed with probability
/// `drop_prob`; surviving paths are rescaled by `1 / keep_prob` so the
/// expected output matches the input (see `apply`).
#[derive(Debug, Clone)]
pub struct DropPath {
    /// Probability of dropping the entire path during training; in `[0, 1]`.
    pub drop_prob: f64,
    /// Cached complement `1.0 - drop_prob`, kept in sync by `new` and
    /// `set_drop_prob`, used to rescale kept paths.
    keep_prob: f64,
}
68
69impl DropPath {
70 pub fn new(drop_prob: f64) -> TrainResult<Self> {
80 if !(0.0..=1.0).contains(&drop_prob) {
81 return Err(TrainError::InvalidParameter(
82 "drop_prob must be in [0, 1]".to_string(),
83 ));
84 }
85
86 Ok(Self {
87 drop_prob,
88 keep_prob: 1.0 - drop_prob,
89 })
90 }
91
92 pub fn apply(
104 &self,
105 path: &ArrayView2<f64>,
106 training: bool,
107 rng: &mut StdRng,
108 ) -> TrainResult<Array2<f64>> {
109 if !training || self.drop_prob == 0.0 {
110 return Ok(path.to_owned());
112 }
113
114 if self.drop_prob == 1.0 {
115 return Ok(Array2::zeros(path.raw_dim()));
117 }
118
119 let should_drop = rng.random::<f64>() < self.drop_prob;
121
122 if should_drop {
123 Ok(Array2::zeros(path.raw_dim()))
125 } else {
126 Ok(path.mapv(|x| x / self.keep_prob))
129 }
130 }
131
132 pub fn apply_batch(
146 &self,
147 paths: &ArrayView2<f64>,
148 training: bool,
149 rng: &mut StdRng,
150 ) -> TrainResult<Array2<f64>> {
151 if !training || self.drop_prob == 0.0 {
152 return Ok(paths.to_owned());
153 }
154
155 let (batch_size, _) = paths.dim();
156 let mut output = paths.to_owned();
157
158 for i in 0..batch_size {
160 let should_drop = rng.random::<f64>() < self.drop_prob;
161 if should_drop {
162 for j in 0..output.ncols() {
164 output[[i, j]] = 0.0;
165 }
166 } else {
167 for j in 0..output.ncols() {
169 output[[i, j]] /= self.keep_prob;
170 }
171 }
172 }
173
174 Ok(output)
175 }
176
177 pub fn keep_probability(&self) -> f64 {
179 self.keep_prob
180 }
181
182 pub fn set_drop_prob(&mut self, drop_prob: f64) -> TrainResult<()> {
184 if !(0.0..=1.0).contains(&drop_prob) {
185 return Err(TrainError::InvalidParameter(
186 "drop_prob must be in [0, 1]".to_string(),
187 ));
188 }
189
190 self.drop_prob = drop_prob;
191 self.keep_prob = 1.0 - drop_prob;
192 Ok(())
193 }
194}
195
/// Per-layer drop-probability schedule that rises linearly with layer depth
/// (see `get_drop_prob`: probability interpolates from `drop_prob_min` at
/// layer 0 to `drop_prob_max` at the last layer).
#[derive(Debug, Clone)]
pub struct LinearStochasticDepth {
    /// Total number of layers in the schedule; must be > 0 (validated in `new`).
    pub num_layers: usize,
    /// Drop probability assigned to the first layer; in `[0, 1]`.
    pub drop_prob_min: f64,
    /// Drop probability assigned to the last layer; in `[0, 1]` and
    /// >= `drop_prob_min` (validated in `new`).
    pub drop_prob_max: f64,
}
224
225impl LinearStochasticDepth {
226 pub fn new(num_layers: usize, drop_prob_min: f64, drop_prob_max: f64) -> TrainResult<Self> {
238 if num_layers == 0 {
239 return Err(TrainError::InvalidParameter(
240 "num_layers must be > 0".to_string(),
241 ));
242 }
243
244 if !(0.0..=1.0).contains(&drop_prob_min) || !(0.0..=1.0).contains(&drop_prob_max) {
245 return Err(TrainError::InvalidParameter(
246 "drop probabilities must be in [0, 1]".to_string(),
247 ));
248 }
249
250 if drop_prob_min > drop_prob_max {
251 return Err(TrainError::InvalidParameter(
252 "drop_prob_min must be <= drop_prob_max".to_string(),
253 ));
254 }
255
256 Ok(Self {
257 num_layers,
258 drop_prob_min,
259 drop_prob_max,
260 })
261 }
262
263 pub fn get_drop_prob(&self, layer_idx: usize) -> f64 {
278 if layer_idx >= self.num_layers {
279 return self.drop_prob_max;
280 }
281
282 if self.num_layers == 1 {
283 return self.drop_prob_min;
284 }
285
286 let ratio = layer_idx as f64 / (self.num_layers - 1) as f64;
288 self.drop_prob_min + (self.drop_prob_max - self.drop_prob_min) * ratio
289 }
290
291 pub fn create_drop_paths(&self) -> TrainResult<Vec<DropPath>> {
297 let mut drop_paths = Vec::with_capacity(self.num_layers);
298
299 for i in 0..self.num_layers {
300 let drop_prob = self.get_drop_prob(i);
301 drop_paths.push(DropPath::new(drop_prob)?);
302 }
303
304 Ok(drop_paths)
305 }
306}
307
/// Per-layer drop-probability schedule with a curved (non-linear) ramp.
///
/// NOTE(review): despite the name, `get_drop_prob` uses a quadratic ramp
/// (ratio squared), not a true exponential — confirm the name is intentional.
#[derive(Debug, Clone)]
pub struct ExponentialStochasticDepth {
    /// Total number of layers in the schedule; must be > 0 (validated in `new`).
    pub num_layers: usize,
    /// Drop probability assigned to the first layer; in `[0, 1]`.
    pub drop_prob_min: f64,
    /// Drop probability assigned to the last layer; in `[0, 1]` and
    /// >= `drop_prob_min` (validated in `new`).
    pub drop_prob_max: f64,
}
321
322impl ExponentialStochasticDepth {
323 pub fn new(num_layers: usize, drop_prob_min: f64, drop_prob_max: f64) -> TrainResult<Self> {
325 if num_layers == 0 {
326 return Err(TrainError::InvalidParameter(
327 "num_layers must be > 0".to_string(),
328 ));
329 }
330
331 if !(0.0..=1.0).contains(&drop_prob_min) || !(0.0..=1.0).contains(&drop_prob_max) {
332 return Err(TrainError::InvalidParameter(
333 "drop probabilities must be in [0, 1]".to_string(),
334 ));
335 }
336
337 if drop_prob_min > drop_prob_max {
338 return Err(TrainError::InvalidParameter(
339 "drop_prob_min must be <= drop_prob_max".to_string(),
340 ));
341 }
342
343 Ok(Self {
344 num_layers,
345 drop_prob_min,
346 drop_prob_max,
347 })
348 }
349
350 pub fn get_drop_prob(&self, layer_idx: usize) -> f64 {
352 if layer_idx >= self.num_layers {
353 return self.drop_prob_max;
354 }
355
356 if self.num_layers == 1 {
357 return self.drop_prob_min;
358 }
359
360 let ratio = layer_idx as f64 / (self.num_layers - 1) as f64;
362 let exp_ratio = ratio * ratio; self.drop_prob_min + (self.drop_prob_max - self.drop_prob_min) * exp_ratio
365 }
366
367 pub fn create_drop_paths(&self) -> TrainResult<Vec<DropPath>> {
369 let mut drop_paths = Vec::with_capacity(self.num_layers);
370
371 for i in 0..self.num_layers {
372 let drop_prob = self.get_drop_prob(i);
373 drop_paths.push(DropPath::new(drop_prob)?);
374 }
375
376 Ok(drop_paths)
377 }
378}
379
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;
    use scirs2_core::random::SeedableRng;

    /// Fixed-seed RNG so the statistical assertions below are reproducible.
    fn create_test_rng() -> StdRng {
        StdRng::seed_from_u64(42)
    }

    #[test]
    fn test_drop_path_creation() {
        let drop_path = DropPath::new(0.2).unwrap();
        assert_eq!(drop_path.drop_prob, 0.2);
        assert!((drop_path.keep_prob - 0.8).abs() < 1e-10);
    }

    #[test]
    fn test_drop_path_invalid_prob() {
        // Probabilities outside [0, 1] must be rejected on either side.
        assert!(DropPath::new(-0.1).is_err());
        assert!(DropPath::new(1.5).is_err());
    }

    #[test]
    fn test_drop_path_zero_prob() {
        let drop_path = DropPath::new(0.0).unwrap();
        let mut rng = create_test_rng();
        let input = array![[1.0, 2.0], [3.0, 4.0]];

        // drop_prob == 0 must be the identity even in training mode.
        let output = drop_path.apply(&input.view(), true, &mut rng).unwrap();
        assert_eq!(output, input);
    }

    #[test]
    fn test_drop_path_full_prob() {
        let drop_path = DropPath::new(1.0).unwrap();
        let mut rng = create_test_rng();
        let input = array![[1.0, 2.0], [3.0, 4.0]];

        // drop_prob == 1 must zero the path unconditionally.
        let output = drop_path.apply(&input.view(), true, &mut rng).unwrap();
        assert_eq!(output, Array2::<f64>::zeros((2, 2)));
    }

    #[test]
    fn test_drop_path_inference_mode() {
        let drop_path = DropPath::new(0.5).unwrap();
        let mut rng = create_test_rng();
        let input = array![[1.0, 2.0], [3.0, 4.0]];

        // training == false disables dropping entirely.
        let output = drop_path.apply(&input.view(), false, &mut rng).unwrap();
        assert_eq!(output, input);
    }

    #[test]
    fn test_drop_path_training_mode() {
        let drop_path = DropPath::new(0.5).unwrap();
        let mut rng = create_test_rng();
        let input = array![[1.0, 2.0]];

        let mut dropped = 0;
        let mut kept = 0;
        for _ in 0..100 {
            let output = drop_path.apply(&input.view(), true, &mut rng).unwrap();
            if output[[0, 0]] == 0.0 {
                dropped += 1;
            } else {
                kept += 1;
                // Kept paths are rescaled: 1.0 / 0.5 == 2.0.
                assert!((output[[0, 0]] - 2.0).abs() < 1e-10);
            }
        }

        // With p = 0.5 over 100 trials, both counts should land near 50.
        assert!(dropped > 30 && dropped < 70);
        assert!(kept > 30 && kept < 70);
    }

    #[test]
    fn test_drop_path_batch() {
        let drop_path = DropPath::new(0.5).unwrap();
        let mut rng = create_test_rng();
        let batch = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];

        let output = drop_path.apply_batch(&batch.view(), true, &mut rng).unwrap();
        assert_eq!(output.shape(), batch.shape());

        // Rows are dropped independently; with p = 0.5 and this seed at least
        // one of the four rows should have been zeroed.
        let zeroed_rows = (0..output.nrows())
            .filter(|&i| output[[i, 0]] == 0.0 && output[[i, 1]] == 0.0)
            .count();
        assert!(zeroed_rows > 0);
    }

    #[test]
    fn test_drop_path_set_prob() {
        let mut drop_path = DropPath::new(0.2).unwrap();
        assert_eq!(drop_path.drop_prob, 0.2);

        drop_path.set_drop_prob(0.5).unwrap();
        assert_eq!(drop_path.drop_prob, 0.5);
        assert!((drop_path.keep_prob - 0.5).abs() < 1e-10);

        // Out-of-range updates are rejected.
        assert!(drop_path.set_drop_prob(1.5).is_err());
    }

    #[test]
    fn test_linear_stochastic_depth_creation() {
        let schedule = LinearStochasticDepth::new(10, 0.0, 0.5).unwrap();
        assert_eq!(schedule.num_layers, 10);
        assert_eq!(schedule.drop_prob_min, 0.0);
        assert_eq!(schedule.drop_prob_max, 0.5);
    }

    #[test]
    fn test_linear_stochastic_depth_invalid() {
        assert!(LinearStochasticDepth::new(0, 0.0, 0.5).is_err());
        assert!(LinearStochasticDepth::new(10, -0.1, 0.5).is_err());
        assert!(LinearStochasticDepth::new(10, 0.0, 1.5).is_err());
        assert!(LinearStochasticDepth::new(10, 0.6, 0.3).is_err());
    }

    #[test]
    fn test_linear_stochastic_depth_interpolation() {
        let schedule = LinearStochasticDepth::new(10, 0.0, 0.9).unwrap();

        // Endpoints and midpoint of the linear ramp over 10 layers.
        assert!((schedule.get_drop_prob(0) - 0.0).abs() < 1e-10);
        assert!((schedule.get_drop_prob(5) - 0.5).abs() < 1e-6);
        assert!((schedule.get_drop_prob(9) - 0.9).abs() < 1e-10);
    }

    #[test]
    fn test_linear_stochastic_depth_create_paths() {
        let schedule = LinearStochasticDepth::new(5, 0.0, 0.4).unwrap();
        let paths = schedule.create_drop_paths().unwrap();

        assert_eq!(paths.len(), 5);
        assert!((paths[0].drop_prob - 0.0).abs() < 1e-10);
        assert!((paths[2].drop_prob - 0.2).abs() < 1e-10);
        assert!((paths[4].drop_prob - 0.4).abs() < 1e-10);
    }

    #[test]
    fn test_exponential_stochastic_depth() {
        let schedule = ExponentialStochasticDepth::new(10, 0.0, 0.8).unwrap();

        assert!((schedule.get_drop_prob(0) - 0.0).abs() < 1e-10);
        assert!((schedule.get_drop_prob(9) - 0.8).abs() < 1e-10);

        // The quadratic curve stays at or below the linear ramp mid-schedule.
        let mid_prob = schedule.get_drop_prob(5);
        let linear_mid = 0.4;
        assert!(mid_prob < linear_mid + 0.1);
    }

    #[test]
    fn test_exponential_create_paths() {
        let schedule = ExponentialStochasticDepth::new(5, 0.0, 0.4).unwrap();
        let paths = schedule.create_drop_paths().unwrap();

        assert_eq!(paths.len(), 5);
        // Drop probabilities must be monotonically non-decreasing with depth.
        for pair in paths.windows(2) {
            assert!(pair[0].drop_prob <= pair[1].drop_prob);
        }
    }
}