1use crate::error::{NeuralError, Result};
9use crate::layers::conv::PaddingMode;
10use crate::layers::{BatchNorm, Conv2D, Dense, Dropout, Layer};
11use scirs2_core::ndarray::{Array, IxDyn, ScalarOperand};
12use scirs2_core::numeric::{Float, NumAssign};
13use scirs2_core::random::{rngs::SmallRng, RngExt, SeedableRng};
14use serde::{Deserialize, Serialize};
15use std::fmt::Debug;
16
17#[allow(dead_code)]
19pub fn swish<F: Float>(x: F) -> F {
20 x * (F::one() + (-x).exp()).recip()
21}
22
/// Hyper-parameters for a single MBConv block.
///
/// Field order is part of the serialized representation; keep it stable.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MBConvConfig {
    /// Number of channels entering the block.
    pub input_channels: usize,
    /// Number of channels produced by the block's 1x1 projection convolution.
    pub output_channels: usize,
    /// Side length of the square kernel of the block's spatial convolution.
    pub kernel_size: usize,
    /// Stride of the spatial convolution, applied to both spatial axes.
    pub stride: usize,
    /// Channel expansion factor; a value of `1` disables the 1x1 expansion convolution.
    pub expand_ratio: usize,
    /// Whether the block inserts a squeeze-and-excitation module.
    pub use_se: bool,
    /// Probability of dropping the residual branch; only used when the block has a skip
    /// connection and the rate is positive.
    pub drop_connect_rate: f64,
}
41
/// One stage of the network: a template block configuration repeated `num_blocks` times
/// (before depth scaling).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EfficientNetStage {
    /// Template for the stage's blocks. When the network is built, the first block takes the
    /// running channel count as input and this stride; later blocks use stride 1.
    /// NOTE(review): `mbconv_config.input_channels` is not read by `EfficientNet::new`.
    pub mbconv_config: MBConvConfig,
    /// Unscaled number of blocks in the stage (scaled by `depth_coefficient`).
    pub num_blocks: usize,
}
50
/// Full network configuration, including the compound-scaling coefficients.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EfficientNetConfig {
    /// Multiplier applied to channel counts (see `scale_channels`).
    pub width_coefficient: f64,
    /// Multiplier applied to per-stage block counts (see `scale_depth`).
    pub depth_coefficient: f64,
    /// Nominal input resolution of the variant.
    /// NOTE(review): stored only; the visible forward pass neither resizes nor validates
    /// the input against it.
    pub resolution: usize,
    /// Dropout probability applied to pooled features before the classifier.
    pub dropout_rate: f64,
    /// Stage table, in network order.
    pub stages: Vec<EfficientNetStage>,
    /// Number of channels in the network input (dimension 1 of a 4-D input).
    pub input_channels: usize,
    /// Number of output logits produced by the classifier.
    pub num_classes: usize,
}
69
70impl EfficientNetConfig {
71 pub fn efficientnet_b0(input_channels: usize, num_classes: usize) -> Self {
73 let stages = vec![
74 EfficientNetStage {
76 mbconv_config: MBConvConfig {
77 input_channels: 32,
78 output_channels: 16,
79 kernel_size: 3,
80 stride: 1,
81 expand_ratio: 1,
82 use_se: true,
83 drop_connect_rate: 0.2,
84 },
85 num_blocks: 1,
86 },
87 EfficientNetStage {
89 mbconv_config: MBConvConfig {
90 input_channels: 16,
91 output_channels: 24,
92 kernel_size: 3,
93 stride: 2,
94 expand_ratio: 6,
95 use_se: true,
96 drop_connect_rate: 0.2,
97 },
98 num_blocks: 2,
99 },
100 EfficientNetStage {
102 mbconv_config: MBConvConfig {
103 input_channels: 24,
104 output_channels: 40,
105 kernel_size: 5,
106 stride: 2,
107 expand_ratio: 6,
108 use_se: true,
109 drop_connect_rate: 0.2,
110 },
111 num_blocks: 2,
112 },
113 EfficientNetStage {
115 mbconv_config: MBConvConfig {
116 input_channels: 40,
117 output_channels: 80,
118 kernel_size: 3,
119 stride: 2,
120 expand_ratio: 6,
121 use_se: true,
122 drop_connect_rate: 0.2,
123 },
124 num_blocks: 3,
125 },
126 EfficientNetStage {
128 mbconv_config: MBConvConfig {
129 input_channels: 80,
130 output_channels: 112,
131 kernel_size: 5,
132 stride: 1,
133 expand_ratio: 6,
134 use_se: true,
135 drop_connect_rate: 0.2,
136 },
137 num_blocks: 3,
138 },
139 EfficientNetStage {
141 mbconv_config: MBConvConfig {
142 input_channels: 112,
143 output_channels: 192,
144 kernel_size: 5,
145 stride: 2,
146 expand_ratio: 6,
147 use_se: true,
148 drop_connect_rate: 0.2,
149 },
150 num_blocks: 4,
151 },
152 EfficientNetStage {
154 mbconv_config: MBConvConfig {
155 input_channels: 192,
156 output_channels: 320,
157 kernel_size: 3,
158 stride: 1,
159 expand_ratio: 6,
160 use_se: true,
161 drop_connect_rate: 0.2,
162 },
163 num_blocks: 1,
164 },
165 ];
166 Self {
167 width_coefficient: 1.0,
168 depth_coefficient: 1.0,
169 resolution: 224,
170 dropout_rate: 0.2,
171 stages,
172 input_channels,
173 num_classes,
174 }
175 }
176
177 pub fn efficientnet_b1(input_channels: usize, num_classes: usize) -> Self {
179 let mut config = Self::efficientnet_b0(input_channels, num_classes);
180 config.width_coefficient = 1.0;
181 config.depth_coefficient = 1.1;
182 config.resolution = 240;
183 config.dropout_rate = 0.2;
184 config
185 }
186
187 pub fn efficientnet_b2(input_channels: usize, num_classes: usize) -> Self {
189 let mut config = Self::efficientnet_b0(input_channels, num_classes);
190 config.width_coefficient = 1.1;
191 config.depth_coefficient = 1.2;
192 config.resolution = 260;
193 config.dropout_rate = 0.3;
194 config
195 }
196
197 pub fn efficientnet_b3(input_channels: usize, num_classes: usize) -> Self {
199 let mut config = Self::efficientnet_b0(input_channels, num_classes);
200 config.width_coefficient = 1.2;
201 config.depth_coefficient = 1.4;
202 config.resolution = 300;
203 config.dropout_rate = 0.3;
204 config
205 }
206
207 pub fn efficientnet_b4(input_channels: usize, num_classes: usize) -> Self {
209 let mut config = Self::efficientnet_b0(input_channels, num_classes);
210 config.width_coefficient = 1.4;
211 config.depth_coefficient = 1.8;
212 config.resolution = 380;
213 config.dropout_rate = 0.4;
214 config
215 }
216
217 pub fn efficientnet_b5(input_channels: usize, num_classes: usize) -> Self {
219 let mut config = Self::efficientnet_b0(input_channels, num_classes);
220 config.width_coefficient = 1.6;
221 config.depth_coefficient = 2.2;
222 config.resolution = 456;
223 config.dropout_rate = 0.4;
224 config
225 }
226
227 pub fn efficientnet_b6(input_channels: usize, num_classes: usize) -> Self {
229 let mut config = Self::efficientnet_b0(input_channels, num_classes);
230 config.width_coefficient = 1.8;
231 config.depth_coefficient = 2.6;
232 config.resolution = 528;
233 config.dropout_rate = 0.5;
234 config
235 }
236
237 pub fn efficientnet_b7(input_channels: usize, num_classes: usize) -> Self {
239 let mut config = Self::efficientnet_b0(input_channels, num_classes);
240 config.width_coefficient = 2.0;
241 config.depth_coefficient = 3.1;
242 config.resolution = 600;
243 config.dropout_rate = 0.5;
244 config
245 }
246
247 pub fn scale_channels(&self, channels: usize) -> usize {
249 let scaled = (channels as f64 * self.width_coefficient).round();
250 (scaled as usize).div_ceil(8) * 8
252 }
253
254 pub fn scale_depth(&self, depth: usize) -> usize {
256 (depth as f64 * self.depth_coefficient).ceil() as usize
257 }
258}
259
/// Squeeze-and-excitation module.
///
/// Averages each channel over the spatial axes, passes the result through two 1x1
/// convolutions (ReLU after the first, sigmoid after the second), and rescales each input
/// channel by the resulting gate.
struct SqueezeExcitation<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> {
    /// Channel count accepted by `forward` (checked against dimension 1 of the input).
    input_channels: usize,
    /// Width of the bottleneck between `fc1` and `fc2`.
    #[allow(dead_code)] // kept for introspection; never read after construction
    squeeze_channels: usize,
    /// 1x1 convolution reducing `input_channels` to `squeeze_channels`.
    fc1: Conv2D<F>,
    /// 1x1 convolution expanding `squeeze_channels` back to `input_channels`.
    fc2: Conv2D<F>,
}
271
272impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> SqueezeExcitation<F> {
273 pub fn new(input_channels: usize, squeeze_channels: usize) -> Result<Self> {
275 let fc1 = Conv2D::new(input_channels, squeeze_channels, (1, 1), (1, 1), None)?
277 .with_padding(PaddingMode::Valid);
278 let fc2 = Conv2D::new(squeeze_channels, input_channels, (1, 1), (1, 1), None)?
280 .with_padding(PaddingMode::Valid);
281 Ok(Self {
282 input_channels,
283 squeeze_channels,
284 fc1,
285 fc2,
286 })
287 }
288}
289
290impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> Layer<F> for SqueezeExcitation<F> {
291 fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
292 let shape = input.shape();
294 if shape.len() != 4 {
295 return Err(NeuralError::InferenceError(format!(
296 "Expected 4D input, got {:?}",
297 shape
298 )));
299 }
300 let batch_size = shape[0];
301 let channels = shape[1];
302 let height = shape[2];
303 let width = shape[3];
304 if channels != self.input_channels {
305 return Err(NeuralError::InferenceError(format!(
306 "Expected {} input channels, got {}",
307 self.input_channels, channels
308 )));
309 }
310 let spatial_size = F::from(height * width).ok_or_else(|| {
312 NeuralError::InferenceError("Failed to convert spatial size".to_string())
313 })?;
314 let mut x = Array::zeros(IxDyn(&[batch_size, channels, 1, 1]));
315 for b in 0..batch_size {
316 for c in 0..channels {
317 let mut sum = F::zero();
318 for h in 0..height {
319 for w in 0..width {
320 sum += input[[b, c, h, w]];
321 }
322 }
323 x[[b, c, 0, 0]] = sum / spatial_size;
324 }
325 }
326 let x = self.fc1.forward(&x)?;
328 let x = x.mapv(|v: F| v.max(F::zero()));
330 let x = self.fc2.forward(&x)?;
332 let x = x.mapv(|v| F::one() / (F::one() + (-v).exp()));
334 let mut result = input.clone();
336 for b in 0..batch_size {
337 for c in 0..channels {
338 let scale = x[[b, c, 0, 0]];
339 for h in 0..height {
340 for w in 0..width {
341 result[[b, c, h, w]] = input[[b, c, h, w]] * scale;
342 }
343 }
344 }
345 }
346 Ok(result)
347 }
348
349 fn backward(
350 &self,
351 input: &Array<F, IxDyn>,
352 grad_output: &Array<F, IxDyn>,
353 ) -> Result<Array<F, IxDyn>> {
354 let shape = input.shape();
357 if shape.len() != 4 {
358 return Ok(grad_output.clone());
359 }
360 let batch_size = shape[0];
361 let channels = shape[1];
362 let height = shape[2];
363 let width = shape[3];
364
365 let spatial_size = F::from(height * width).ok_or_else(|| {
367 NeuralError::InferenceError("Failed to convert spatial size".to_string())
368 })?;
369 let mut pooled = Array::zeros(IxDyn(&[batch_size, channels, 1, 1]));
370 for b in 0..batch_size {
371 for c in 0..channels {
372 let mut sum = F::zero();
373 for h in 0..height {
374 for w in 0..width {
375 sum += input[[b, c, h, w]];
376 }
377 }
378 pooled[[b, c, 0, 0]] = sum / spatial_size;
379 }
380 }
381 let squeezed = self.fc1.forward(&pooled)?;
382 let relu_out = squeezed.mapv(|v: F| v.max(F::zero()));
383 let excited = self.fc2.forward(&relu_out)?;
384 let scale = excited.mapv(|v| F::one() / (F::one() + (-v).exp()));
385
386 let mut grad_input = grad_output.clone();
388 for b in 0..batch_size {
389 for c in 0..channels {
390 let s = scale[[b, c, 0, 0]];
391 for h in 0..height {
392 for w in 0..width {
393 grad_input[[b, c, h, w]] = grad_output[[b, c, h, w]] * s;
394 }
395 }
396 }
397 }
398 Ok(grad_input)
399 }
400
401 fn update(&mut self, learning_rate: F) -> Result<()> {
402 self.fc1.update(learning_rate)?;
403 self.fc2.update(learning_rate)?;
404 Ok(())
405 }
406
407 fn as_any(&self) -> &dyn std::any::Any {
408 self
409 }
410
411 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
412 self
413 }
414}
415
/// One MBConv block.
///
/// The block applies an optional 1x1 expansion, a KxK spatial convolution, an optional
/// squeeze-and-excitation module and a 1x1 projection. Each convolution is followed by
/// batch norm.
struct MBConvBlock<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> {
    /// Configuration the block was built from; stored but not read after construction.
    #[allow(dead_code)]
    config: MBConvConfig,
    /// Set when `input_channels == output_channels && stride == 1`; enables the residual add
    /// and drop-connect.
    has_skip_connection: bool,
    /// 1x1 expansion convolution; `None` when `expand_ratio == 1`.
    expand_conv: Option<Conv2D<F>>,
    /// Batch norm after the expansion convolution; `Some` exactly when `expand_conv` is.
    expand_bn: Option<BatchNorm<F>>,
    /// KxK spatial convolution with `Same` padding and the configured stride.
    /// NOTE(review): built as a plain `Conv2D` (no grouping argument), not a true depthwise conv.
    depthwise_conv: Conv2D<F>,
    /// Batch norm after the spatial convolution.
    depthwise_bn: BatchNorm<F>,
    /// Optional squeeze-and-excitation module on the expanded channels.
    se: Option<SqueezeExcitation<F>>,
    /// 1x1 projection to `output_channels`.
    project_conv: Conv2D<F>,
    /// Batch norm after the projection; no activation follows it.
    project_bn: BatchNorm<F>,
    /// Drop-connect probability converted to `F`.
    drop_connect_rate: F,
}
440
441impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> MBConvBlock<F> {
442 pub fn new(config: MBConvConfig) -> Result<Self> {
444 let input_channels = config.input_channels;
445 let output_channels = config.output_channels;
446 let expand_ratio = config.expand_ratio;
447 let kernel_size = config.kernel_size;
448 let stride = config.stride;
449 let use_se = config.use_se;
450 let drop_connect_rate = F::from(config.drop_connect_rate).ok_or_else(|| {
451 NeuralError::InvalidArchitecture("Failed to convert drop_connect_rate".to_string())
452 })?;
453
454 let mut rng = SmallRng::from_seed([42; 32]);
455
456 let has_skip_connection = input_channels == output_channels && stride == 1;
458
459 let (expand_conv, expand_bn) = if expand_ratio != 1 {
461 let expanded_channels = input_channels * expand_ratio;
462 let conv = Conv2D::new(input_channels, expanded_channels, (1, 1), (1, 1), None)?
464 .with_padding(PaddingMode::Valid);
465 let bn = BatchNorm::new(expanded_channels, 1e-3, 0.01, &mut rng)?;
467 (Some(conv), Some(bn))
468 } else {
469 (None, None)
470 };
471
472 let expanded_channels = if expand_ratio != 1 {
474 input_channels * expand_ratio
475 } else {
476 input_channels
477 };
478
479 let depthwise_conv = Conv2D::new(
481 expanded_channels,
482 expanded_channels,
483 (kernel_size, kernel_size),
484 (stride, stride),
485 None,
486 )?
487 .with_padding(PaddingMode::Same);
488
489 let depthwise_bn = BatchNorm::new(expanded_channels, 1e-3, 0.01, &mut rng)?;
491
492 let se = if use_se {
494 let squeeze_channels = (expanded_channels as f64 / 4.0).round() as usize;
495 Some(SqueezeExcitation::new(expanded_channels, squeeze_channels)?)
496 } else {
497 None
498 };
499
500 let project_conv = Conv2D::new(expanded_channels, output_channels, (1, 1), (1, 1), None)?
502 .with_padding(PaddingMode::Valid);
503
504 let project_bn = BatchNorm::new(output_channels, 1e-3, 0.01, &mut rng)?;
506
507 Ok(Self {
508 config,
509 has_skip_connection,
510 expand_conv,
511 expand_bn,
512 depthwise_conv,
513 depthwise_bn,
514 se,
515 project_conv,
516 project_bn,
517 drop_connect_rate,
518 })
519 }
520
521 fn drop_connect<R: scirs2_core::random::Rng>(
523 &self,
524 input: &Array<F, IxDyn>,
525 rng: &mut R,
526 ) -> Array<F, IxDyn> {
527 if self.drop_connect_rate <= F::zero() || !self.has_skip_connection {
528 return input.clone();
529 }
530
531 let shape = input.shape();
532 let mut result = input.clone();
533
534 let keep_prob = F::one() - self.drop_connect_rate;
536 if rng.random::<f64>() > self.drop_connect_rate.to_f64().unwrap_or(0.0) {
537 result = result.mapv(|x| x / keep_prob);
539 } else {
540 result = Array::zeros(IxDyn(shape));
542 }
543 result
544 }
545}
546
547impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> Layer<F> for MBConvBlock<F> {
548 fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
549 let mut rng = SmallRng::from_seed([42; 32]);
550 let mut x = input.clone();
551
552 if let (Some(ref expand_conv), Some(ref expand_bn)) = (&self.expand_conv, &self.expand_bn) {
554 x = expand_conv.forward(&x)?;
555 x = expand_bn.forward(&x)?;
556 x = x.mapv(swish); }
558
559 x = self.depthwise_conv.forward(&x)?;
561 x = self.depthwise_bn.forward(&x)?;
562 x = x.mapv(swish); if let Some(ref se) = self.se {
566 x = se.forward(&x)?;
567 }
568
569 x = self.project_conv.forward(&x)?;
571 x = self.project_bn.forward(&x)?;
572
573 if self.has_skip_connection {
575 x = self.drop_connect(&x, &mut rng);
577 let mut result = input.clone();
579 for i in 0..result.len() {
580 result[i] += x[i];
581 }
582 x = result;
583 }
584
585 Ok(x)
586 }
587
588 fn backward(
589 &self,
590 input: &Array<F, IxDyn>,
591 grad_output: &Array<F, IxDyn>,
592 ) -> Result<Array<F, IxDyn>> {
593 let mut grad = grad_output.clone();
596
597 grad = self.project_bn.backward(input, &grad)?;
599 grad = self.project_conv.backward(input, &grad)?;
601
602 if let Some(ref se) = self.se {
604 grad = se.backward(input, &grad)?;
605 }
606
607 grad = self.depthwise_bn.backward(input, &grad)?;
610 grad = self.depthwise_conv.backward(input, &grad)?;
611
612 if let (Some(ref expand_conv), Some(ref expand_bn)) = (&self.expand_conv, &self.expand_bn) {
614 grad = expand_bn.backward(input, &grad)?;
615 grad = expand_conv.backward(input, &grad)?;
616 }
617
618 if self.has_skip_connection {
620 let mut result = grad_output.clone();
622 for i in 0..result.len() {
623 result[i] += grad[i];
624 }
625 return Ok(result);
626 }
627
628 Ok(grad)
629 }
630
631 fn update(&mut self, learning_rate: F) -> Result<()> {
632 if let (Some(ref mut expand_conv), Some(ref mut expand_bn)) =
634 (&mut self.expand_conv, &mut self.expand_bn)
635 {
636 expand_conv.update(learning_rate)?;
637 expand_bn.update(learning_rate)?;
638 }
639
640 self.depthwise_conv.update(learning_rate)?;
642 self.depthwise_bn.update(learning_rate)?;
643
644 if let Some(ref mut se) = self.se {
646 se.update(learning_rate)?;
647 }
648
649 self.project_conv.update(learning_rate)?;
651 self.project_bn.update(learning_rate)?;
652
653 Ok(())
654 }
655
656 fn as_any(&self) -> &dyn std::any::Any {
657 self
658 }
659
660 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
661 self
662 }
663}
664
/// EfficientNet image classifier.
///
/// The network applies, in order: a stem convolution, a stack of MBConv blocks, a 1x1 head
/// convolution, global average pooling, dropout, and a dense classifier.
pub struct EfficientNet<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> {
    /// Configuration the network was built from.
    config: EfficientNetConfig,
    /// 3x3, stride-2 stem convolution with `Same` padding.
    stem_conv: Conv2D<F>,
    /// Batch norm after the stem convolution.
    stem_bn: BatchNorm<F>,
    /// All MBConv blocks of all stages, flattened in network order.
    blocks: Vec<MBConvBlock<F>>,
    /// 1x1 head convolution to `scale_channels(1280)` channels.
    head_conv: Conv2D<F>,
    /// Batch norm after the head convolution.
    head_bn: BatchNorm<F>,
    /// Dense layer mapping pooled head features to `num_classes` logits.
    classifier: Dense<F>,
    /// Dropout applied to the pooled features before the classifier.
    dropout: Dropout<F>,
}
684
685impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> EfficientNet<F> {
686 pub fn new(config: EfficientNetConfig) -> Result<Self> {
688 let mut rng = SmallRng::from_seed([42; 32]);
689 let num_classes = config.num_classes;
690 let input_channels = config.input_channels;
691
692 let stem_channels = config.scale_channels(32);
694 let stem_conv = Conv2D::new(input_channels, stem_channels, (3, 3), (2, 2), None)?
695 .with_padding(PaddingMode::Same);
696
697 let stem_bn = BatchNorm::new(stem_channels, 1e-3, 0.01, &mut rng)?;
699
700 let mut blocks = Vec::new();
702 let mut in_channels = stem_channels;
703 for stage in &config.stages {
704 let num_blocks = config.scale_depth(stage.num_blocks);
706 let out_channels = config.scale_channels(stage.mbconv_config.output_channels);
707
708 let first_block_config = MBConvConfig {
710 input_channels: in_channels,
711 output_channels: out_channels,
712 kernel_size: stage.mbconv_config.kernel_size,
713 stride: stage.mbconv_config.stride,
714 expand_ratio: stage.mbconv_config.expand_ratio,
715 use_se: stage.mbconv_config.use_se,
716 drop_connect_rate: stage.mbconv_config.drop_connect_rate,
717 };
718 blocks.push(MBConvBlock::new(first_block_config)?);
719
720 for _ in 1..num_blocks {
722 let block_config = MBConvConfig {
723 input_channels: out_channels,
724 output_channels: out_channels,
725 kernel_size: stage.mbconv_config.kernel_size,
726 stride: 1,
727 expand_ratio: stage.mbconv_config.expand_ratio,
728 use_se: stage.mbconv_config.use_se,
729 drop_connect_rate: stage.mbconv_config.drop_connect_rate,
730 };
731 blocks.push(MBConvBlock::new(block_config)?);
732 }
733
734 in_channels = out_channels;
735 }
736
737 let head_channels = config.scale_channels(1280);
739 let head_conv = Conv2D::new(in_channels, head_channels, (1, 1), (1, 1), None)?
740 .with_padding(PaddingMode::Valid);
741
742 let head_bn = BatchNorm::new(head_channels, 1e-3, 0.01, &mut rng)?;
744
745 let classifier = Dense::new(head_channels, num_classes, None, &mut rng)?;
747
748 let dropout = Dropout::new(config.dropout_rate, &mut rng)?;
750
751 Ok(Self {
752 config,
753 stem_conv,
754 stem_bn,
755 blocks,
756 head_conv,
757 head_bn,
758 classifier,
759 dropout,
760 })
761 }
762
763 pub fn efficientnet_b0(input_channels: usize, num_classes: usize) -> Result<Self> {
765 let config = EfficientNetConfig::efficientnet_b0(input_channels, num_classes);
766 Self::new(config)
767 }
768
769 pub fn efficientnet_b1(input_channels: usize, num_classes: usize) -> Result<Self> {
771 let config = EfficientNetConfig::efficientnet_b1(input_channels, num_classes);
772 Self::new(config)
773 }
774
775 pub fn efficientnet_b2(input_channels: usize, num_classes: usize) -> Result<Self> {
777 let config = EfficientNetConfig::efficientnet_b2(input_channels, num_classes);
778 Self::new(config)
779 }
780
781 pub fn efficientnet_b3(input_channels: usize, num_classes: usize) -> Result<Self> {
783 let config = EfficientNetConfig::efficientnet_b3(input_channels, num_classes);
784 Self::new(config)
785 }
786
787 pub fn efficientnet_b4(input_channels: usize, num_classes: usize) -> Result<Self> {
789 let config = EfficientNetConfig::efficientnet_b4(input_channels, num_classes);
790 Self::new(config)
791 }
792
793 pub fn efficientnet_b5(input_channels: usize, num_classes: usize) -> Result<Self> {
795 let config = EfficientNetConfig::efficientnet_b5(input_channels, num_classes);
796 Self::new(config)
797 }
798
799 pub fn efficientnet_b6(input_channels: usize, num_classes: usize) -> Result<Self> {
801 let config = EfficientNetConfig::efficientnet_b6(input_channels, num_classes);
802 Self::new(config)
803 }
804
805 pub fn efficientnet_b7(input_channels: usize, num_classes: usize) -> Result<Self> {
807 let config = EfficientNetConfig::efficientnet_b7(input_channels, num_classes);
808 Self::new(config)
809 }
810
811 pub fn config(&self) -> &EfficientNetConfig {
813 &self.config
814 }
815}
816
817impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> Layer<F> for EfficientNet<F> {
818 fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
819 let shape = input.shape();
820 if shape.len() != 4 || shape[1] != self.config.input_channels {
822 return Err(NeuralError::InferenceError(format!(
823 "Expected input shape [batch_size, {}, height, width], got {:?}",
824 self.config.input_channels, shape
825 )));
826 }
827
828 let batch_size = shape[0];
829
830 let mut x = self.stem_conv.forward(input)?;
832 x = self.stem_bn.forward(&x)?;
833 x = x.mapv(swish); for block in &self.blocks {
837 x = block.forward(&x)?;
838 }
839
840 x = self.head_conv.forward(&x)?;
842 x = self.head_bn.forward(&x)?;
843 x = x.mapv(swish); let channels = x.shape()[1];
847 let height = x.shape()[2];
848 let width = x.shape()[3];
849 let mut pooled = Array::zeros(IxDyn(&[batch_size, channels]));
850 for b in 0..batch_size {
851 for c in 0..channels {
852 let mut sum = F::zero();
853 for h in 0..height {
854 for w in 0..width {
855 sum += x[[b, c, h, w]];
856 }
857 }
858 pooled[[b, c]] = sum / F::from(height * width).unwrap_or(F::one());
859 }
860 }
861
862 let pooled = self.dropout.forward(&pooled)?;
864 let logits = self.classifier.forward(&pooled)?;
865
866 Ok(logits)
867 }
868
869 fn backward(
870 &self,
871 input: &Array<F, IxDyn>,
872 grad_output: &Array<F, IxDyn>,
873 ) -> Result<Array<F, IxDyn>> {
874 let mut grad = self.classifier.backward(input, grad_output)?;
876
877 grad = self.dropout.backward(input, &grad)?;
879
880 grad = self.head_bn.backward(input, &grad)?;
882 grad = self.head_conv.backward(input, &grad)?;
883
884 for block in self.blocks.iter().rev() {
886 grad = block.backward(input, &grad)?;
887 }
888
889 grad = self.stem_bn.backward(input, &grad)?;
891 grad = self.stem_conv.backward(input, &grad)?;
892
893 Ok(grad)
894 }
895
896 fn update(&mut self, learning_rate: F) -> Result<()> {
897 self.stem_conv.update(learning_rate)?;
899 self.stem_bn.update(learning_rate)?;
900
901 for block in &mut self.blocks {
903 block.update(learning_rate)?;
904 }
905
906 self.head_conv.update(learning_rate)?;
908 self.head_bn.update(learning_rate)?;
909
910 self.classifier.update(learning_rate)?;
912
913 Ok(())
914 }
915
916 fn as_any(&self) -> &dyn std::any::Any {
917 self
918 }
919
920 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
921 self
922 }
923}
924
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array4;

    /// Builds a small single-stage config: one 8-channel MBConv block with no expansion,
    /// no SE and no drop-connect.
    ///
    /// With width 1.0 the stem outputs 32 channels, so the single block maps 32 -> 8 and has
    /// no skip connection.
    fn minimal_efficientnet_config(
        input_channels: usize,
        num_classes: usize,
    ) -> EfficientNetConfig {
        EfficientNetConfig {
            width_coefficient: 1.0,
            depth_coefficient: 1.0,
            resolution: 224,
            dropout_rate: 0.2,
            stages: vec![EfficientNetStage {
                mbconv_config: MBConvConfig {
                    input_channels: 8,
                    output_channels: 8,
                    kernel_size: 3,
                    stride: 1,
                    expand_ratio: 1,
                    use_se: false,
                    drop_connect_rate: 0.0,
                },
                num_blocks: 1,
            }],
            input_channels,
            num_classes,
        }
    }

    /// Checks the B0 config values, then builds a model from the minimal config rather than
    /// the full B0 stage table.
    #[test]
    fn test_efficientnet_b0_creation() {
        let config = EfficientNetConfig::efficientnet_b0(3, 10);
        assert_eq!(config.resolution, 224);
        assert_eq!(config.num_classes, 10);
        assert_eq!(config.input_channels, 3);
        assert_eq!(config.stages.len(), 7);
        assert!((config.width_coefficient - 1.0).abs() < f64::EPSILON);
        assert!((config.depth_coefficient - 1.0).abs() < f64::EPSILON);

        let result = EfficientNet::<f32>::new(minimal_efficientnet_config(3, 10));
        assert!(result.is_ok());
    }

    /// Width scaling yields multiples of 8. Depth scaling for B3 is ceil(2 * 1.4) = 3.
    #[test]
    fn test_efficientnet_config_scaling() {
        let config = EfficientNetConfig::efficientnet_b0(3, 10);
        let scaled = config.scale_channels(32);
        assert_eq!(scaled % 8, 0);
        assert_eq!(scaled, 32);

        let config_b3 = EfficientNetConfig::efficientnet_b3(3, 10);
        let scaled_b3 = config_b3.scale_channels(32);
        assert_eq!(scaled_b3 % 8, 0);
        assert!(scaled_b3 >= 32);

        let depth_scaled = config_b3.scale_depth(2);
        assert_eq!(depth_scaled, 3);
    }

    /// Every variant reuses the 7-stage B0 table with its own resolution.
    #[test]
    fn test_efficientnet_all_variants() {
        let configs = [
            EfficientNetConfig::efficientnet_b0(3, 10),
            EfficientNetConfig::efficientnet_b1(3, 10),
            EfficientNetConfig::efficientnet_b2(3, 10),
            EfficientNetConfig::efficientnet_b3(3, 10),
            EfficientNetConfig::efficientnet_b4(3, 10),
            EfficientNetConfig::efficientnet_b5(3, 10),
            EfficientNetConfig::efficientnet_b6(3, 10),
            EfficientNetConfig::efficientnet_b7(3, 10),
        ];

        let expected_resolutions = [224, 240, 260, 300, 380, 456, 528, 600];
        for (i, config) in configs.iter().enumerate() {
            assert_eq!(
                config.resolution, expected_resolutions[i],
                "B{} resolution mismatch",
                i
            );
            assert_eq!(config.stages.len(), 7, "B{} should have 7 stages", i);
        }
    }

    /// SE forward preserves shape and produces finite values.
    #[test]
    fn test_squeeze_excitation_forward() {
        let channels = 16;
        let se = SqueezeExcitation::<f64>::new(channels, 4).expect("Test: SE creation");

        let input = Array4::<f64>::from_elem((1, channels, 2, 2), 0.5).into_dyn();
        let output = se.forward(&input);
        assert!(output.is_ok());
        let out = output.expect("Test: SE forward");
        assert_eq!(out.shape(), input.shape());
        assert!(out.iter().all(|&v| v.is_finite()));
    }

    /// SE backward preserves shape and produces finite values.
    #[test]
    fn test_squeeze_excitation_backward() {
        let channels = 8;
        let se = SqueezeExcitation::<f64>::new(channels, 2).expect("Test: SE creation");

        let input = Array4::<f64>::from_elem((1, channels, 2, 2), 0.3).into_dyn();
        let grad_output = Array4::<f64>::from_elem((1, channels, 2, 2), 0.1).into_dyn();

        let grad_input = se.backward(&input, &grad_output);
        assert!(grad_input.is_ok());
        let gi = grad_input.expect("Test: SE backward");
        assert_eq!(gi.shape(), input.shape());
        assert!(gi.iter().all(|&v| v.is_finite()));
    }

    /// A block with expansion and SE can be constructed.
    #[test]
    fn test_mbconv_block_creation() {
        let config = MBConvConfig {
            input_channels: 16,
            output_channels: 24,
            kernel_size: 3,
            stride: 1,
            expand_ratio: 6,
            use_se: true,
            drop_connect_rate: 0.2,
        };
        let block = MBConvBlock::<f64>::new(config);
        assert!(block.is_ok());
    }

    /// The skip connection requires equal channel counts and stride 1.
    #[test]
    fn test_mbconv_skip_connection() {
        let config_skip = MBConvConfig {
            input_channels: 16,
            output_channels: 16,
            kernel_size: 3,
            stride: 1,
            expand_ratio: 1,
            use_se: false,
            drop_connect_rate: 0.0,
        };
        let block = MBConvBlock::<f64>::new(config_skip).expect("Test: MBConv skip creation");
        assert!(block.has_skip_connection);

        let config_no_skip = MBConvConfig {
            input_channels: 16,
            output_channels: 16,
            kernel_size: 3,
            stride: 2,
            expand_ratio: 1,
            use_se: false,
            drop_connect_rate: 0.0,
        };
        let block_ns =
            MBConvBlock::<f64>::new(config_no_skip).expect("Test: MBConv no-skip creation");
        assert!(!block_ns.has_skip_connection);
    }

    /// SE forward rejects non-4-D input.
    #[test]
    fn test_se_invalid_input_dims() {
        let se = SqueezeExcitation::<f64>::new(8, 2).expect("Test: SE creation");
        let bad_input = Array::zeros(IxDyn(&[1, 8, 4]));
        assert!(se.forward(&bad_input).is_err());
    }

    /// SE forward rejects a channel-count mismatch.
    #[test]
    fn test_se_channel_mismatch() {
        let se = SqueezeExcitation::<f64>::new(8, 2).expect("Test: SE creation");
        let bad_input = Array4::<f64>::zeros((1, 4, 2, 2)).into_dyn();
        assert!(se.forward(&bad_input).is_err());
    }

    /// swish(0) = 0, swish(x) ~ x for large x, and swish is small and negative for
    /// negative x.
    #[test]
    fn test_swish_activation() {
        assert!((swish(0.0_f64)).abs() < 1e-10);
        let large_val = swish(10.0_f64);
        assert!((large_val - 10.0).abs() < 0.01);
        let neg_val = swish(-5.0_f64);
        assert!(neg_val < 0.0);
        assert!(neg_val > -1.0);
    }

    /// Exercises only the stem convolution of the minimal model, not a full forward pass.
    #[test]
    fn test_efficientnet_b0_forward_stem() {
        let config = minimal_efficientnet_config(3, 10);
        let b0_config = EfficientNetConfig::efficientnet_b0(3, 10);
        assert_eq!(b0_config.resolution, 224);
        assert_eq!(b0_config.num_classes, 10);
        assert_eq!(b0_config.input_channels, 3);
        assert_eq!(b0_config.stages.len(), 7);

        let model = EfficientNet::<f32>::new(config).expect("Test: minimal model creation");
        let stem_input = Array4::<f32>::from_elem((1, 3, 8, 8), 0.1_f32).into_dyn();
        let stem_output = model.stem_conv.forward(&stem_input);
        assert!(stem_output.is_ok(), "stem conv forward should succeed");
        let out = stem_output.expect("Test: stem forward");
        assert_eq!(out.shape()[0], 1, "batch size preserved");
        assert!(
            out.iter().all(|v| v.is_finite()),
            "no NaN/Inf in stem output"
        );
    }

    /// The network forward rejects a wrong channel count and non-4-D input before running
    /// any layer.
    #[test]
    fn test_efficientnet_invalid_input() {
        let config = minimal_efficientnet_config(3, 10);
        let model = EfficientNet::<f32>::new(config).expect("Test: minimal model creation");

        let bad_input = Array4::<f32>::from_elem((1, 1, 8, 8), 0.1_f32).into_dyn();
        assert!(
            model.forward(&bad_input).is_err(),
            "wrong channel count should return Err"
        );

        let bad_dims = Array::zeros(IxDyn(&[1_usize, 3, 8]));
        assert!(
            model.forward(&bad_dims).is_err(),
            "3D input should return Err"
        );
    }
}