1use crate::activations::GELU;
9use crate::error::{NeuralError, Result};
10use crate::layers::conv::PaddingMode;
11use crate::layers::{Conv2D, Dense, Dropout, GlobalAvgPool2D, Layer, LayerNorm2D, Sequential};
12use scirs2_core::ndarray::{Array, IxDyn, ScalarOperand};
13use scirs2_core::numeric::{Float, NumAssign};
14use scirs2_core::random::{rngs::SmallRng, SeedableRng};
15use serde::{Deserialize, Serialize};
16use std::fmt::Debug;
17
/// Configuration for a single [`ConvNeXtStage`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConvNeXtStageConfig {
    /// Number of channels entering the stage.
    pub input_channels: usize,
    /// Number of channels produced by the stage; every block in the stage runs
    /// at this width.
    pub output_channels: usize,
    /// Number of [`ConvNeXtBlock`]s stacked in the stage.
    pub num_blocks: usize,
    /// Kernel size and stride of the downsampling convolution. A downsample
    /// layer is created when this is greater than 1 or when
    /// `input_channels != output_channels`.
    pub stride: usize,
    /// Initial value of every entry of each block's per-channel layer scale
    /// (`gamma`).
    pub layer_scale_init_value: f64,
    /// Drop-path probability passed to every block; a positive value makes the
    /// block scale its residual branch by `1 - drop_path_prob`.
    pub drop_path_prob: f64,
}
34
/// Top-level configuration for a [`ConvNeXt`] network.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConvNeXtConfig {
    /// Preset this configuration describes. `ConvNeXt::new` builds the network
    /// from `depths`/`dims` and only stores this value.
    pub variant: ConvNeXtVariant,
    /// Number of channels of the input image (input of the stem convolution).
    pub input_channels: usize,
    /// Number of blocks in each stage; one stage is built per entry.
    pub depths: Vec<usize>,
    /// Channel width of each stage; `dims[0]` is also the stem's output width.
    pub dims: Vec<usize>,
    /// Output size of the dense classifier in the head.
    pub num_classes: usize,
    /// Dropout rate applied before the classifier; `None` or a non-positive
    /// rate adds no dropout layer.
    pub dropout_rate: Option<f64>,
    /// Initial value of the per-channel layer scale in every block.
    pub layer_scale_init_value: f64,
    /// When `true`, a classification head (LayerNorm2D, global average pooling,
    /// optional dropout, dense classifier) is appended after the last stage.
    pub include_top: bool,
}
55
56impl Default for ConvNeXtConfig {
57 fn default() -> Self {
58 Self {
59 variant: ConvNeXtVariant::Tiny,
60 input_channels: 3,
61 depths: vec![3, 3, 9, 3],
62 dims: vec![96, 192, 384, 768],
63 num_classes: 1000,
64 dropout_rate: Some(0.0),
65 layer_scale_init_value: 1e-6,
66 include_top: true,
67 }
68 }
69}
70
/// Named ConvNeXt model sizes, matching the presets built by the
/// `ConvNeXt::convnext_*` constructors.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ConvNeXtVariant {
    /// Depths `[3, 3, 9, 3]`, widths `[96, 192, 384, 768]`.
    Tiny,
    /// Depths `[3, 3, 27, 3]`, widths `[96, 192, 384, 768]`.
    Small,
    /// Depths `[3, 3, 27, 3]`, widths `[128, 256, 512, 1024]`.
    Base,
    /// Depths `[3, 3, 27, 3]`, widths `[192, 384, 768, 1536]`.
    Large,
    /// Depths `[3, 3, 27, 3]`, widths `[256, 512, 1024, 2048]`.
    XLarge,
}
85
/// One ConvNeXt residual block.
///
/// The residual branch is 7x7 conv -> norm -> 1x1 conv (expands to 4x channels)
/// -> GELU -> 1x1 conv (projects back). Its output is scaled per channel by
/// `gamma`, optionally multiplied by `skip_scale`, and added to the block input.
#[derive(Debug, Clone)]
pub struct ConvNeXtBlock<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> {
    /// 7x7 convolution, stride 1, padding 3, `channels -> channels`.
    /// NOTE(review): named "depthwise", but it is built with plain
    /// `Conv2D::new` — confirm whether it is actually grouped per channel.
    pub depthwise_conv: Conv2D<F>,
    /// Normalization applied after the 7x7 convolution.
    pub norm: LayerNorm2D<F>,
    /// 1x1 convolution expanding `channels -> 4 * channels`.
    pub pointwise_conv1: Conv2D<F>,
    /// Activation between the two pointwise convolutions.
    pub gelu: GELU,
    /// 1x1 convolution projecting `4 * channels -> channels`.
    pub pointwise_conv2: Conv2D<F>,
    /// Per-channel layer scale, shape `[channels]`.
    pub gamma: Array<F, IxDyn>,
    /// `true` when the block was built with `drop_path_prob > 0`. It controls
    /// whether `skip_scale` is applied to the residual branch; the identity
    /// path is always added.
    pub use_skip: bool,
    /// `1 - drop_path_prob`, multiplied into the residual branch when `use_skip`.
    pub skip_scale: F,
}
109
110impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> ConvNeXtBlock<F> {
111 pub fn new(channels: usize, layer_scale_init_value: f64, drop_path_prob: f64) -> Result<Self> {
113 let depthwise_conv = Conv2D::<F>::new(channels, channels, (7, 7), (1, 1), None)
114 .map(|c| c.with_padding(PaddingMode::Custom(3)))?;
115
116 let norm = LayerNorm2D::<F>::new::<SmallRng>(channels, 1e-6, Some("norm"))?;
117
118 let pointwise_conv1 = Conv2D::<F>::new(channels, channels * 4, (1, 1), (1, 1), None)
119 .map(|c| c.with_padding(PaddingMode::Custom(0)))?;
120
121 let gelu = GELU::new();
122
123 let pointwise_conv2 = Conv2D::<F>::new(channels * 4, channels, (1, 1), (1, 1), None)
124 .map(|c| c.with_padding(PaddingMode::Custom(0)))?;
125
126 let gamma_value = F::from(layer_scale_init_value).ok_or_else(|| {
127 NeuralError::InvalidArchitecture(
128 "ConvNeXtBlock: failed to convert layer_scale_init_value to float".to_string(),
129 )
130 })?;
131 let gamma = Array::<F, _>::from_elem(IxDyn(&[channels]), gamma_value);
132
133 let skip_scale = F::from(1.0 - drop_path_prob).ok_or_else(|| {
134 NeuralError::InvalidArchitecture(
135 "ConvNeXtBlock: failed to convert drop_path_prob to float".to_string(),
136 )
137 })?;
138 let use_skip = drop_path_prob > 0.0;
139
140 Ok(Self {
141 depthwise_conv,
142 norm,
143 pointwise_conv1,
144 gelu,
145 pointwise_conv2,
146 gamma,
147 use_skip,
148 skip_scale,
149 })
150 }
151}
152
153impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> Layer<F>
154 for ConvNeXtBlock<F>
155{
156 fn as_any(&self) -> &dyn std::any::Any {
157 self
158 }
159
160 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
161 self
162 }
163
164 fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
165 let identity = input.clone();
166
167 let mut x = self.depthwise_conv.forward(input)?;
169 x = self.norm.forward(&x)?;
170
171 x = self.pointwise_conv1.forward(&x)?;
173 x = <GELU as Layer<F>>::forward(&self.gelu, &x)?;
174 x = self.pointwise_conv2.forward(&x)?;
175
176 let shape = x.shape().to_vec();
178 if shape.len() == 4 {
179 let (n, c, h, w) = (shape[0], shape[1], shape[2], shape[3]);
180 for ni in 0..n {
181 for ci in 0..c {
182 let g = self.gamma[ci];
183 for hi in 0..h {
184 for wi in 0..w {
185 x[[ni, ci, hi, wi]] *= g;
186 }
187 }
188 }
189 }
190 }
191
192 if self.use_skip {
194 x *= self.skip_scale;
195 }
196
197 Ok(x + identity)
198 }
199
200 fn backward(
201 &self,
202 input: &Array<F, IxDyn>,
203 grad_output: &Array<F, IxDyn>,
204 ) -> Result<Array<F, IxDyn>> {
205 let mut grad = grad_output.clone();
206 let grad_skip = grad.clone();
207
208 if self.use_skip {
209 grad *= self.skip_scale;
210 }
211
212 let shape = grad.shape().to_vec();
214 if shape.len() == 4 {
215 let (n, c, h, w) = (shape[0], shape[1], shape[2], shape[3]);
216 for ni in 0..n {
217 for ci in 0..c {
218 let g = self.gamma[ci];
219 for hi in 0..h {
220 for wi in 0..w {
221 grad[[ni, ci, hi, wi]] *= g;
222 }
223 }
224 }
225 }
226 }
227
228 let grad_after_conv2 = self.pointwise_conv2.backward(&grad, &grad)?;
229 let grad_after_gelu = grad_after_conv2.clone();
230 let grad_after_conv1 = self
231 .pointwise_conv1
232 .backward(&grad_after_gelu, &grad_after_gelu)?;
233 let grad_after_norm = self.norm.backward(&grad_after_conv1, &grad_after_conv1)?;
234 let grad_after_dwconv = self.depthwise_conv.backward(input, &grad_after_norm)?;
235
236 Ok(grad_after_dwconv + grad_skip)
237 }
238
239 fn update(&mut self, learning_rate: F) -> Result<()> {
240 self.depthwise_conv.update(learning_rate)?;
241 self.norm.update(learning_rate)?;
242 self.pointwise_conv1.update(learning_rate)?;
243 self.pointwise_conv2.update(learning_rate)?;
244
245 let small_update = F::from(0.0001_f64).ok_or_else(|| {
247 NeuralError::InvalidArchitecture(
248 "ConvNeXtBlock: failed to convert small_update to float".to_string(),
249 )
250 })? * learning_rate;
251 for elem in self.gamma.iter_mut() {
252 *elem -= small_update;
253 }
254 Ok(())
255 }
256
257 fn params(&self) -> Vec<Array<F, IxDyn>> {
258 let mut params = Vec::new();
259 params.extend(self.depthwise_conv.params());
260 params.extend(self.norm.params());
261 params.extend(self.pointwise_conv1.params());
262 params.extend(self.pointwise_conv2.params());
263 params.push(self.gamma.clone());
264 params
265 }
266
267 fn set_training(&mut self, training: bool) {
268 self.depthwise_conv.set_training(training);
269 self.norm.set_training(training);
270 self.pointwise_conv1.set_training(training);
271 self.pointwise_conv2.set_training(training);
272 <GELU as Layer<F>>::set_training(&mut self.gelu, training);
273 }
274
275 fn is_training(&self) -> bool {
276 self.depthwise_conv.is_training()
277 }
278
279 fn layer_type(&self) -> &str {
280 "ConvNeXtBlock"
281 }
282}
283
/// Downsampling layer placed at the start of a [`ConvNeXtStage`].
///
/// It applies normalization followed by a strided convolution whose kernel
/// size equals its stride and which uses no padding.
#[derive(Debug, Clone)]
pub struct ConvNeXtDownsample<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> {
    /// Normalization applied to the incoming feature map.
    pub norm: LayerNorm2D<F>,
    /// `stride x stride` convolution with stride `stride` and zero padding,
    /// mapping `in_channels -> out_channels`.
    pub conv: Conv2D<F>,
}
292
293impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> ConvNeXtDownsample<F> {
294 pub fn new(in_channels: usize, out_channels: usize, stride: usize) -> Result<Self> {
296 let norm = LayerNorm2D::<F>::new::<SmallRng>(in_channels, 1e-6, Some("downsample_norm"))?;
297 let conv = Conv2D::<F>::new(
298 in_channels,
299 out_channels,
300 (stride, stride),
301 (stride, stride),
302 None,
303 )
304 .map(|c| c.with_padding(PaddingMode::Custom(0)))?;
305 Ok(Self { norm, conv })
306 }
307}
308
309impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> Layer<F>
310 for ConvNeXtDownsample<F>
311{
312 fn as_any(&self) -> &dyn std::any::Any {
313 self
314 }
315
316 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
317 self
318 }
319
320 fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
321 let x = self.norm.forward(input)?;
322 self.conv.forward(&x)
323 }
324
325 fn backward(
326 &self,
327 input: &Array<F, IxDyn>,
328 grad_output: &Array<F, IxDyn>,
329 ) -> Result<Array<F, IxDyn>> {
330 let grad_after_conv = self.conv.backward(grad_output, grad_output)?;
331 self.norm.backward(input, &grad_after_conv)
332 }
333
334 fn update(&mut self, learning_rate: F) -> Result<()> {
335 self.norm.update(learning_rate)?;
336 self.conv.update(learning_rate)?;
337 Ok(())
338 }
339
340 fn params(&self) -> Vec<Array<F, IxDyn>> {
341 let mut params = Vec::new();
342 params.extend(self.norm.params());
343 params.extend(self.conv.params());
344 params
345 }
346
347 fn set_training(&mut self, training: bool) {
348 self.norm.set_training(training);
349 self.conv.set_training(training);
350 }
351
352 fn is_training(&self) -> bool {
353 self.norm.is_training()
354 }
355
356 fn layer_type(&self) -> &str {
357 "ConvNeXtDownsample"
358 }
359}
360
/// One ConvNeXt stage: an optional downsample layer followed by a stack of
/// [`ConvNeXtBlock`]s that all run at the stage's output width.
#[derive(Debug, Clone)]
pub struct ConvNeXtStage<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> {
    /// Present when the stage changes channel count or uses a stride > 1.
    pub downsample: Option<ConvNeXtDownsample<F>>,
    /// Residual blocks applied in order after the (optional) downsample.
    pub blocks: Vec<ConvNeXtBlock<F>>,
}
369
370impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> ConvNeXtStage<F> {
371 pub fn new(config: &ConvNeXtStageConfig) -> Result<Self> {
373 let downsample = if config.input_channels != config.output_channels || config.stride > 1 {
374 Some(ConvNeXtDownsample::<F>::new(
375 config.input_channels,
376 config.output_channels,
377 config.stride,
378 )?)
379 } else {
380 None
381 };
382
383 let mut blocks = Vec::with_capacity(config.num_blocks);
384 for _ in 0..config.num_blocks {
385 blocks.push(ConvNeXtBlock::<F>::new(
386 config.output_channels,
387 config.layer_scale_init_value,
388 config.drop_path_prob,
389 )?);
390 }
391
392 Ok(Self { downsample, blocks })
393 }
394}
395
396impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> Layer<F>
397 for ConvNeXtStage<F>
398{
399 fn as_any(&self) -> &dyn std::any::Any {
400 self
401 }
402
403 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
404 self
405 }
406
407 fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
408 let mut x = if let Some(ref ds) = self.downsample {
409 ds.forward(input)?
410 } else {
411 input.clone()
412 };
413 for block in &self.blocks {
414 x = block.forward(&x)?;
415 }
416 Ok(x)
417 }
418
419 fn backward(
420 &self,
421 input: &Array<F, IxDyn>,
422 grad_output: &Array<F, IxDyn>,
423 ) -> Result<Array<F, IxDyn>> {
424 let mut grad = grad_output.clone();
425 for block in self.blocks.iter().rev() {
426 grad = block.backward(&grad, &grad)?;
427 }
428 if let Some(ref ds) = self.downsample {
429 grad = ds.backward(input, &grad)?;
430 }
431 Ok(grad)
432 }
433
434 fn update(&mut self, learning_rate: F) -> Result<()> {
435 if let Some(ref mut ds) = self.downsample {
436 ds.update(learning_rate)?;
437 }
438 for block in &mut self.blocks {
439 block.update(learning_rate)?;
440 }
441 Ok(())
442 }
443
444 fn params(&self) -> Vec<Array<F, IxDyn>> {
445 let mut params = Vec::new();
446 if let Some(ref ds) = self.downsample {
447 params.extend(ds.params());
448 }
449 for block in &self.blocks {
450 params.extend(block.params());
451 }
452 params
453 }
454
455 fn set_training(&mut self, training: bool) {
456 if let Some(ref mut ds) = self.downsample {
457 ds.set_training(training);
458 }
459 for block in &mut self.blocks {
460 block.set_training(training);
461 }
462 }
463
464 fn is_training(&self) -> bool {
465 if let Some(ref ds) = self.downsample {
466 return ds.is_training();
467 }
468 if !self.blocks.is_empty() {
469 return self.blocks[0].is_training();
470 }
471 true
472 }
473
474 fn layer_type(&self) -> &str {
475 "ConvNeXtStage"
476 }
477}
478
/// ConvNeXt image model: stem, a sequence of stages, and an optional
/// classification head.
#[derive(Debug)]
pub struct ConvNeXt<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> {
    /// 4x4 stride-4 convolution followed by `LayerNorm2D`.
    pub stem: Sequential<F>,
    /// One stage per entry of `config.depths`.
    pub stages: Vec<ConvNeXtStage<F>>,
    /// LayerNorm2D -> global average pooling -> optional dropout -> dense
    /// classifier; `None` when `config.include_top` is `false`.
    pub head: Option<Sequential<F>>,
    /// Configuration the model was built from.
    pub config: ConvNeXtConfig,
}
491
492impl<
493 F: Float
494 + Debug
495 + ScalarOperand
496 + Send
497 + Sync
498 + NumAssign
499 + scirs2_core::simd_ops::SimdUnifiedOps
500 + 'static,
501 > ConvNeXt<F>
502{
503 pub fn new(config: ConvNeXtConfig) -> Result<Self> {
505 let mut rng = SmallRng::from_seed([99u8; 32]);
506
507 let mut stem = Sequential::new();
509 stem.add(
510 Conv2D::<F>::new(config.input_channels, config.dims[0], (4, 4), (4, 4), None)
511 .map(|c| c.with_padding(PaddingMode::Custom(0)))?,
512 );
513 stem.add(LayerNorm2D::<F>::new::<SmallRng>(
514 config.dims[0],
515 1e-6,
516 Some("stem_norm"),
517 )?);
518
519 let mut stages = Vec::with_capacity(config.depths.len());
521 let mut current_channels = config.dims[0];
522
523 for (i, &depth) in config.depths.iter().enumerate() {
524 let output_channels = config.dims[i];
525 let stride = if i == 0 { 1 } else { 2 };
526
527 let stage_config = ConvNeXtStageConfig {
528 input_channels: current_channels,
529 output_channels,
530 num_blocks: depth,
531 stride,
532 layer_scale_init_value: config.layer_scale_init_value,
533 drop_path_prob: 0.0,
534 };
535
536 stages.push(ConvNeXtStage::<F>::new(&stage_config)?);
537 current_channels = output_channels;
538 }
539
540 let head = if config.include_top {
542 let last_dim = *config.dims.last().ok_or_else(|| {
543 NeuralError::InvalidArchitecture("ConvNeXt: dims must be non-empty".to_string())
544 })?;
545 let mut head_seq = Sequential::new();
546 head_seq.add(LayerNorm2D::<F>::new::<SmallRng>(
547 last_dim,
548 1e-6,
549 Some("head_norm"),
550 )?);
551 head_seq.add(GlobalAvgPool2D::<F>::new(Some("head_pool")));
553 if let Some(dropout_rate) = config.dropout_rate {
554 if dropout_rate > 0.0 {
555 head_seq.add(Dropout::<F>::new(dropout_rate, &mut rng)?);
556 }
557 }
558 head_seq.add(Dense::<F>::new(
559 last_dim,
560 config.num_classes,
561 Some("classifier"),
562 &mut rng,
563 )?);
564 Some(head_seq)
565 } else {
566 None
567 };
568
569 Ok(Self {
570 stem,
571 stages,
572 head,
573 config,
574 })
575 }
576
577 pub fn convnext_tiny(num_classes: usize, include_top: bool) -> Result<Self> {
579 Self::new(ConvNeXtConfig {
580 variant: ConvNeXtVariant::Tiny,
581 input_channels: 3,
582 depths: vec![3, 3, 9, 3],
583 dims: vec![96, 192, 384, 768],
584 num_classes,
585 dropout_rate: Some(0.1),
586 layer_scale_init_value: 1e-6,
587 include_top,
588 })
589 }
590
591 pub fn convnext_small(num_classes: usize, include_top: bool) -> Result<Self> {
593 Self::new(ConvNeXtConfig {
594 variant: ConvNeXtVariant::Small,
595 input_channels: 3,
596 depths: vec![3, 3, 27, 3],
597 dims: vec![96, 192, 384, 768],
598 num_classes,
599 dropout_rate: Some(0.1),
600 layer_scale_init_value: 1e-6,
601 include_top,
602 })
603 }
604
605 pub fn convnext_base(num_classes: usize, include_top: bool) -> Result<Self> {
607 Self::new(ConvNeXtConfig {
608 variant: ConvNeXtVariant::Base,
609 input_channels: 3,
610 depths: vec![3, 3, 27, 3],
611 dims: vec![128, 256, 512, 1024],
612 num_classes,
613 dropout_rate: Some(0.1),
614 layer_scale_init_value: 1e-6,
615 include_top,
616 })
617 }
618
619 pub fn convnext_large(num_classes: usize, include_top: bool) -> Result<Self> {
621 Self::new(ConvNeXtConfig {
622 variant: ConvNeXtVariant::Large,
623 input_channels: 3,
624 depths: vec![3, 3, 27, 3],
625 dims: vec![192, 384, 768, 1536],
626 num_classes,
627 dropout_rate: Some(0.1),
628 layer_scale_init_value: 1e-6,
629 include_top,
630 })
631 }
632
633 pub fn convnext_xlarge(num_classes: usize, include_top: bool) -> Result<Self> {
635 Self::new(ConvNeXtConfig {
636 variant: ConvNeXtVariant::XLarge,
637 input_channels: 3,
638 depths: vec![3, 3, 27, 3],
639 dims: vec![256, 512, 1024, 2048],
640 num_classes,
641 dropout_rate: Some(0.1),
642 layer_scale_init_value: 1e-6,
643 include_top,
644 })
645 }
646}
647
648impl<
649 F: Float
650 + Debug
651 + ScalarOperand
652 + Send
653 + Sync
654 + NumAssign
655 + scirs2_core::simd_ops::SimdUnifiedOps
656 + 'static,
657 > Layer<F> for ConvNeXt<F>
658{
659 fn as_any(&self) -> &dyn std::any::Any {
660 self
661 }
662
663 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
664 self
665 }
666
667 fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
668 let mut x = self.stem.forward(input)?;
669 for stage in &self.stages {
670 x = stage.forward(&x)?;
671 }
672 if let Some(ref head) = self.head {
673 x = head.forward(&x)?;
674 }
675 Ok(x)
676 }
677
678 fn backward(
679 &self,
680 input: &Array<F, IxDyn>,
681 grad_output: &Array<F, IxDyn>,
682 ) -> Result<Array<F, IxDyn>> {
683 let mut grad = grad_output.clone();
684 if let Some(ref head) = self.head {
685 grad = head.backward(&grad, &grad)?;
686 }
687 for stage in self.stages.iter().rev() {
688 grad = stage.backward(&grad, &grad)?;
689 }
690 self.stem.backward(input, &grad)
691 }
692
693 fn update(&mut self, learning_rate: F) -> Result<()> {
694 self.stem.update(learning_rate)?;
695 for stage in &mut self.stages {
696 stage.update(learning_rate)?;
697 }
698 if let Some(ref mut head) = self.head {
699 head.update(learning_rate)?;
700 }
701 Ok(())
702 }
703
704 fn params(&self) -> Vec<Array<F, IxDyn>> {
705 let mut params = Vec::new();
706 params.extend(self.stem.params());
707 for stage in &self.stages {
708 params.extend(stage.params());
709 }
710 if let Some(ref head) = self.head {
711 params.extend(head.params());
712 }
713 params
714 }
715
716 fn set_training(&mut self, training: bool) {
717 self.stem.set_training(training);
718 for stage in &mut self.stages {
719 stage.set_training(training);
720 }
721 if let Some(ref mut head) = self.head {
722 head.set_training(training);
723 }
724 }
725
726 fn is_training(&self) -> bool {
727 self.stem.is_training()
728 }
729
730 fn layer_type(&self) -> &str {
731 "ConvNeXt"
732 }
733}
734
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration is the four-stage Tiny layout on 3-channel input.
    #[test]
    fn test_convnext_config() {
        let ConvNeXtConfig {
            variant,
            input_channels,
            depths,
            dims,
            ..
        } = ConvNeXtConfig::default();
        assert_eq!(variant, ConvNeXtVariant::Tiny);
        assert_eq!(input_channels, 3);
        assert_eq!(depths.len(), 4);
        assert_eq!(dims.len(), 4);
    }

    /// A block without drop path can be constructed.
    #[test]
    fn test_convnext_block_creation() {
        assert!(ConvNeXtBlock::<f64>::new(64, 1e-6, 0.0).is_ok());
    }

    /// A widening, strided stage (which needs a downsample layer) can be constructed.
    #[test]
    fn test_convnext_stage_config() {
        let widening_stage = ConvNeXtStageConfig {
            input_channels: 64,
            output_channels: 128,
            num_blocks: 3,
            stride: 2,
            layer_scale_init_value: 1e-6,
            drop_path_prob: 0.0,
        };
        assert!(ConvNeXtStage::<f64>::new(&widening_stage).is_ok());
    }

    /// A stride-2 downsample layer can be constructed.
    #[test]
    fn test_convnext_downsample() {
        assert!(ConvNeXtDownsample::<f64>::new(64, 128, 2).is_ok());
    }

    /// Variant equality behaves as a plain enum comparison.
    #[test]
    fn test_convnext_variants() {
        let tiny = ConvNeXtVariant::Tiny;
        assert_eq!(tiny, ConvNeXtVariant::Tiny);
        assert_ne!(tiny, ConvNeXtVariant::Base);
    }
}