Skip to main content

scirs2_neural/models/architectures/
convnext.rs

1//! ConvNeXt architecture implementation
2//!
3//! This module implements the ConvNeXt architecture as described in
4//! "A ConvNet for the 2020s" (<https://arxiv.org/abs/2201.03545>)
5//! ConvNeXt modernizes ResNet architecture by incorporating design choices from
6//! Vision Transformers, resulting in a pure convolutional model with excellent performance.
7
8use crate::activations::GELU;
9use crate::error::{NeuralError, Result};
10use crate::layers::conv::PaddingMode;
11use crate::layers::{Conv2D, Dense, Dropout, GlobalAvgPool2D, Layer, LayerNorm2D, Sequential};
12use scirs2_core::ndarray::{Array, IxDyn, ScalarOperand};
13use scirs2_core::numeric::{Float, NumAssign};
14use scirs2_core::random::{rngs::SmallRng, SeedableRng};
15use serde::{Deserialize, Serialize};
16use std::fmt::Debug;
17
18/// Configuration for a ConvNeXt stage
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct ConvNeXtStageConfig {
21    /// Number of input channels
22    pub input_channels: usize,
23    /// Number of output channels
24    pub output_channels: usize,
25    /// Number of blocks in this stage
26    pub num_blocks: usize,
27    /// Stride for the first block (typically 2 for downsampling, 1 otherwise)
28    pub stride: usize,
29    /// Layer scale initialization value (typically 1e-6)
30    pub layer_scale_init_value: f64,
31    /// Dropout probability
32    pub drop_path_prob: f64,
33}
34
35/// Configuration for a ConvNeXt model
36#[derive(Debug, Clone, Serialize, Deserialize)]
37pub struct ConvNeXtConfig {
38    /// Model depth variant (Tiny, Small, Base, Large, XLarge)
39    pub variant: ConvNeXtVariant,
40    /// Number of input channels (typically 3 for RGB images)
41    pub input_channels: usize,
42    /// Depths for each stage
43    pub depths: Vec<usize>,
44    /// Dimensions (channels) for each stage
45    pub dims: Vec<usize>,
46    /// Number of output classes
47    pub num_classes: usize,
48    /// Dropout rate
49    pub dropout_rate: Option<f64>,
50    /// Layer scale initialization value
51    pub layer_scale_init_value: f64,
52    /// Whether to include the classification head
53    pub include_top: bool,
54}
55
56impl Default for ConvNeXtConfig {
57    fn default() -> Self {
58        Self {
59            variant: ConvNeXtVariant::Tiny,
60            input_channels: 3,
61            depths: vec![3, 3, 9, 3],
62            dims: vec![96, 192, 384, 768],
63            num_classes: 1000,
64            dropout_rate: Some(0.0),
65            layer_scale_init_value: 1e-6,
66            include_top: true,
67        }
68    }
69}
70
71/// ConvNeXt model variants
72#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
73pub enum ConvNeXtVariant {
74    /// ConvNeXt-Tiny
75    Tiny,
76    /// ConvNeXt-Small
77    Small,
78    /// ConvNeXt-Base
79    Base,
80    /// ConvNeXt-Large
81    Large,
82    /// ConvNeXt-XLarge
83    XLarge,
84}
85
86/// ConvNeXt residual block.
87///
88/// Each block applies: depthwise 7×7 conv → LayerNorm2D → 1×1 conv (×4 channels) →
89/// GELU → 1×1 conv (back to original channels) → layer scale → skip connection.
90#[derive(Debug, Clone)]
91pub struct ConvNeXtBlock<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> {
92    /// Depthwise convolution (7×7, same-padding)
93    pub depthwise_conv: Conv2D<F>,
94    /// Layer normalization over spatial dims
95    pub norm: LayerNorm2D<F>,
96    /// Pointwise convolution 1 (channels → channels×4)
97    pub pointwise_conv1: Conv2D<F>,
98    /// GELU activation
99    pub gelu: GELU,
100    /// Pointwise convolution 2 (channels×4 → channels)
101    pub pointwise_conv2: Conv2D<F>,
102    /// Layer scale gamma parameter, shape `[channels]`
103    pub gamma: Array<F, IxDyn>,
104    /// Whether to apply stochastic-depth scaling
105    pub use_skip: bool,
106    /// Scale factor for stochastic depth: `1 - drop_path_prob`
107    pub skip_scale: F,
108}
109
110impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> ConvNeXtBlock<F> {
111    /// Create a new ConvNeXtBlock.
112    pub fn new(channels: usize, layer_scale_init_value: f64, drop_path_prob: f64) -> Result<Self> {
113        let depthwise_conv = Conv2D::<F>::new(channels, channels, (7, 7), (1, 1), None)
114            .map(|c| c.with_padding(PaddingMode::Custom(3)))?;
115
116        let norm = LayerNorm2D::<F>::new::<SmallRng>(channels, 1e-6, Some("norm"))?;
117
118        let pointwise_conv1 = Conv2D::<F>::new(channels, channels * 4, (1, 1), (1, 1), None)
119            .map(|c| c.with_padding(PaddingMode::Custom(0)))?;
120
121        let gelu = GELU::new();
122
123        let pointwise_conv2 = Conv2D::<F>::new(channels * 4, channels, (1, 1), (1, 1), None)
124            .map(|c| c.with_padding(PaddingMode::Custom(0)))?;
125
126        let gamma_value = F::from(layer_scale_init_value).ok_or_else(|| {
127            NeuralError::InvalidArchitecture(
128                "ConvNeXtBlock: failed to convert layer_scale_init_value to float".to_string(),
129            )
130        })?;
131        let gamma = Array::<F, _>::from_elem(IxDyn(&[channels]), gamma_value);
132
133        let skip_scale = F::from(1.0 - drop_path_prob).ok_or_else(|| {
134            NeuralError::InvalidArchitecture(
135                "ConvNeXtBlock: failed to convert drop_path_prob to float".to_string(),
136            )
137        })?;
138        let use_skip = drop_path_prob > 0.0;
139
140        Ok(Self {
141            depthwise_conv,
142            norm,
143            pointwise_conv1,
144            gelu,
145            pointwise_conv2,
146            gamma,
147            use_skip,
148            skip_scale,
149        })
150    }
151}
152
153impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> Layer<F>
154    for ConvNeXtBlock<F>
155{
156    fn as_any(&self) -> &dyn std::any::Any {
157        self
158    }
159
160    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
161        self
162    }
163
164    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
165        let identity = input.clone();
166
167        // Depthwise conv → LayerNorm2D
168        let mut x = self.depthwise_conv.forward(input)?;
169        x = self.norm.forward(&x)?;
170
171        // Pointwise expand → GELU → pointwise project
172        x = self.pointwise_conv1.forward(&x)?;
173        x = <GELU as Layer<F>>::forward(&self.gelu, &x)?;
174        x = self.pointwise_conv2.forward(&x)?;
175
176        // Apply layer scale: broadcast gamma [C] over [N,C,H,W]
177        let shape = x.shape().to_vec();
178        if shape.len() == 4 {
179            let (n, c, h, w) = (shape[0], shape[1], shape[2], shape[3]);
180            for ni in 0..n {
181                for ci in 0..c {
182                    let g = self.gamma[ci];
183                    for hi in 0..h {
184                        for wi in 0..w {
185                            x[[ni, ci, hi, wi]] *= g;
186                        }
187                    }
188                }
189            }
190        }
191
192        // Stochastic depth
193        if self.use_skip {
194            x *= self.skip_scale;
195        }
196
197        Ok(x + identity)
198    }
199
200    fn backward(
201        &self,
202        input: &Array<F, IxDyn>,
203        grad_output: &Array<F, IxDyn>,
204    ) -> Result<Array<F, IxDyn>> {
205        let mut grad = grad_output.clone();
206        let grad_skip = grad.clone();
207
208        if self.use_skip {
209            grad *= self.skip_scale;
210        }
211
212        // Undo layer scale
213        let shape = grad.shape().to_vec();
214        if shape.len() == 4 {
215            let (n, c, h, w) = (shape[0], shape[1], shape[2], shape[3]);
216            for ni in 0..n {
217                for ci in 0..c {
218                    let g = self.gamma[ci];
219                    for hi in 0..h {
220                        for wi in 0..w {
221                            grad[[ni, ci, hi, wi]] *= g;
222                        }
223                    }
224                }
225            }
226        }
227
228        let grad_after_conv2 = self.pointwise_conv2.backward(&grad, &grad)?;
229        let grad_after_gelu = grad_after_conv2.clone();
230        let grad_after_conv1 = self
231            .pointwise_conv1
232            .backward(&grad_after_gelu, &grad_after_gelu)?;
233        let grad_after_norm = self.norm.backward(&grad_after_conv1, &grad_after_conv1)?;
234        let grad_after_dwconv = self.depthwise_conv.backward(input, &grad_after_norm)?;
235
236        Ok(grad_after_dwconv + grad_skip)
237    }
238
239    fn update(&mut self, learning_rate: F) -> Result<()> {
240        self.depthwise_conv.update(learning_rate)?;
241        self.norm.update(learning_rate)?;
242        self.pointwise_conv1.update(learning_rate)?;
243        self.pointwise_conv2.update(learning_rate)?;
244
245        // Small gradient-like update on gamma (simplified)
246        let small_update = F::from(0.0001_f64).ok_or_else(|| {
247            NeuralError::InvalidArchitecture(
248                "ConvNeXtBlock: failed to convert small_update to float".to_string(),
249            )
250        })? * learning_rate;
251        for elem in self.gamma.iter_mut() {
252            *elem -= small_update;
253        }
254        Ok(())
255    }
256
257    fn params(&self) -> Vec<Array<F, IxDyn>> {
258        let mut params = Vec::new();
259        params.extend(self.depthwise_conv.params());
260        params.extend(self.norm.params());
261        params.extend(self.pointwise_conv1.params());
262        params.extend(self.pointwise_conv2.params());
263        params.push(self.gamma.clone());
264        params
265    }
266
267    fn set_training(&mut self, training: bool) {
268        self.depthwise_conv.set_training(training);
269        self.norm.set_training(training);
270        self.pointwise_conv1.set_training(training);
271        self.pointwise_conv2.set_training(training);
272        <GELU as Layer<F>>::set_training(&mut self.gelu, training);
273    }
274
275    fn is_training(&self) -> bool {
276        self.depthwise_conv.is_training()
277    }
278
279    fn layer_type(&self) -> &str {
280        "ConvNeXtBlock"
281    }
282}
283
284/// ConvNeXt downsampling layer: LayerNorm2D followed by a strided convolution.
285#[derive(Debug, Clone)]
286pub struct ConvNeXtDownsample<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> {
287    /// Layer normalization before convolution
288    pub norm: LayerNorm2D<F>,
289    /// Strided convolution for spatial downsampling
290    pub conv: Conv2D<F>,
291}
292
293impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> ConvNeXtDownsample<F> {
294    /// Create a new ConvNeXtDownsample.
295    pub fn new(in_channels: usize, out_channels: usize, stride: usize) -> Result<Self> {
296        let norm = LayerNorm2D::<F>::new::<SmallRng>(in_channels, 1e-6, Some("downsample_norm"))?;
297        let conv = Conv2D::<F>::new(
298            in_channels,
299            out_channels,
300            (stride, stride),
301            (stride, stride),
302            None,
303        )
304        .map(|c| c.with_padding(PaddingMode::Custom(0)))?;
305        Ok(Self { norm, conv })
306    }
307}
308
309impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> Layer<F>
310    for ConvNeXtDownsample<F>
311{
312    fn as_any(&self) -> &dyn std::any::Any {
313        self
314    }
315
316    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
317        self
318    }
319
320    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
321        let x = self.norm.forward(input)?;
322        self.conv.forward(&x)
323    }
324
325    fn backward(
326        &self,
327        input: &Array<F, IxDyn>,
328        grad_output: &Array<F, IxDyn>,
329    ) -> Result<Array<F, IxDyn>> {
330        let grad_after_conv = self.conv.backward(grad_output, grad_output)?;
331        self.norm.backward(input, &grad_after_conv)
332    }
333
334    fn update(&mut self, learning_rate: F) -> Result<()> {
335        self.norm.update(learning_rate)?;
336        self.conv.update(learning_rate)?;
337        Ok(())
338    }
339
340    fn params(&self) -> Vec<Array<F, IxDyn>> {
341        let mut params = Vec::new();
342        params.extend(self.norm.params());
343        params.extend(self.conv.params());
344        params
345    }
346
347    fn set_training(&mut self, training: bool) {
348        self.norm.set_training(training);
349        self.conv.set_training(training);
350    }
351
352    fn is_training(&self) -> bool {
353        self.norm.is_training()
354    }
355
356    fn layer_type(&self) -> &str {
357        "ConvNeXtDownsample"
358    }
359}
360
361/// A single ConvNeXt stage: optional downsampling layer followed by ConvNeXt blocks.
362#[derive(Debug, Clone)]
363pub struct ConvNeXtStage<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> {
364    /// Optional downsampling layer (present when channels change or stride > 1)
365    pub downsample: Option<ConvNeXtDownsample<F>>,
366    /// ConvNeXt residual blocks
367    pub blocks: Vec<ConvNeXtBlock<F>>,
368}
369
370impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> ConvNeXtStage<F> {
371    /// Create a new ConvNeXtStage.
372    pub fn new(config: &ConvNeXtStageConfig) -> Result<Self> {
373        let downsample = if config.input_channels != config.output_channels || config.stride > 1 {
374            Some(ConvNeXtDownsample::<F>::new(
375                config.input_channels,
376                config.output_channels,
377                config.stride,
378            )?)
379        } else {
380            None
381        };
382
383        let mut blocks = Vec::with_capacity(config.num_blocks);
384        for _ in 0..config.num_blocks {
385            blocks.push(ConvNeXtBlock::<F>::new(
386                config.output_channels,
387                config.layer_scale_init_value,
388                config.drop_path_prob,
389            )?);
390        }
391
392        Ok(Self { downsample, blocks })
393    }
394}
395
396impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> Layer<F>
397    for ConvNeXtStage<F>
398{
399    fn as_any(&self) -> &dyn std::any::Any {
400        self
401    }
402
403    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
404        self
405    }
406
407    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
408        let mut x = if let Some(ref ds) = self.downsample {
409            ds.forward(input)?
410        } else {
411            input.clone()
412        };
413        for block in &self.blocks {
414            x = block.forward(&x)?;
415        }
416        Ok(x)
417    }
418
419    fn backward(
420        &self,
421        input: &Array<F, IxDyn>,
422        grad_output: &Array<F, IxDyn>,
423    ) -> Result<Array<F, IxDyn>> {
424        let mut grad = grad_output.clone();
425        for block in self.blocks.iter().rev() {
426            grad = block.backward(&grad, &grad)?;
427        }
428        if let Some(ref ds) = self.downsample {
429            grad = ds.backward(input, &grad)?;
430        }
431        Ok(grad)
432    }
433
434    fn update(&mut self, learning_rate: F) -> Result<()> {
435        if let Some(ref mut ds) = self.downsample {
436            ds.update(learning_rate)?;
437        }
438        for block in &mut self.blocks {
439            block.update(learning_rate)?;
440        }
441        Ok(())
442    }
443
444    fn params(&self) -> Vec<Array<F, IxDyn>> {
445        let mut params = Vec::new();
446        if let Some(ref ds) = self.downsample {
447            params.extend(ds.params());
448        }
449        for block in &self.blocks {
450            params.extend(block.params());
451        }
452        params
453    }
454
455    fn set_training(&mut self, training: bool) {
456        if let Some(ref mut ds) = self.downsample {
457            ds.set_training(training);
458        }
459        for block in &mut self.blocks {
460            block.set_training(training);
461        }
462    }
463
464    fn is_training(&self) -> bool {
465        if let Some(ref ds) = self.downsample {
466            return ds.is_training();
467        }
468        if !self.blocks.is_empty() {
469            return self.blocks[0].is_training();
470        }
471        true
472    }
473
474    fn layer_type(&self) -> &str {
475        "ConvNeXtStage"
476    }
477}
478
479/// Full ConvNeXt model: stem → stages → optional classification head.
480#[derive(Debug)]
481pub struct ConvNeXt<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> {
482    /// Stem layer (4×4 strided conv + LayerNorm2D)
483    pub stem: Sequential<F>,
484    /// Main stages of the network
485    pub stages: Vec<ConvNeXtStage<F>>,
486    /// Classification head (if include_top is true)
487    pub head: Option<Sequential<F>>,
488    /// Model configuration
489    pub config: ConvNeXtConfig,
490}
491
492impl<
493        F: Float
494            + Debug
495            + ScalarOperand
496            + Send
497            + Sync
498            + NumAssign
499            + scirs2_core::simd_ops::SimdUnifiedOps
500            + 'static,
501    > ConvNeXt<F>
502{
503    /// Create a new ConvNeXt model from config.
504    pub fn new(config: ConvNeXtConfig) -> Result<Self> {
505        let mut rng = SmallRng::from_seed([99u8; 32]);
506
507        // Stem: 4×4 strided conv + LayerNorm2D
508        let mut stem = Sequential::new();
509        stem.add(
510            Conv2D::<F>::new(config.input_channels, config.dims[0], (4, 4), (4, 4), None)
511                .map(|c| c.with_padding(PaddingMode::Custom(0)))?,
512        );
513        stem.add(LayerNorm2D::<F>::new::<SmallRng>(
514            config.dims[0],
515            1e-6,
516            Some("stem_norm"),
517        )?);
518
519        // Stages
520        let mut stages = Vec::with_capacity(config.depths.len());
521        let mut current_channels = config.dims[0];
522
523        for (i, &depth) in config.depths.iter().enumerate() {
524            let output_channels = config.dims[i];
525            let stride = if i == 0 { 1 } else { 2 };
526
527            let stage_config = ConvNeXtStageConfig {
528                input_channels: current_channels,
529                output_channels,
530                num_blocks: depth,
531                stride,
532                layer_scale_init_value: config.layer_scale_init_value,
533                drop_path_prob: 0.0,
534            };
535
536            stages.push(ConvNeXtStage::<F>::new(&stage_config)?);
537            current_channels = output_channels;
538        }
539
540        // Head
541        let head = if config.include_top {
542            let last_dim = *config.dims.last().ok_or_else(|| {
543                NeuralError::InvalidArchitecture("ConvNeXt: dims must be non-empty".to_string())
544            })?;
545            let mut head_seq = Sequential::new();
546            head_seq.add(LayerNorm2D::<F>::new::<SmallRng>(
547                last_dim,
548                1e-6,
549                Some("head_norm"),
550            )?);
551            // GlobalAvgPool2D::new returns Self (not Result), so no `?`
552            head_seq.add(GlobalAvgPool2D::<F>::new(Some("head_pool")));
553            if let Some(dropout_rate) = config.dropout_rate {
554                if dropout_rate > 0.0 {
555                    head_seq.add(Dropout::<F>::new(dropout_rate, &mut rng)?);
556                }
557            }
558            head_seq.add(Dense::<F>::new(
559                last_dim,
560                config.num_classes,
561                Some("classifier"),
562                &mut rng,
563            )?);
564            Some(head_seq)
565        } else {
566            None
567        };
568
569        Ok(Self {
570            stem,
571            stages,
572            head,
573            config,
574        })
575    }
576
577    /// Create a ConvNeXt-Tiny model.
578    pub fn convnext_tiny(num_classes: usize, include_top: bool) -> Result<Self> {
579        Self::new(ConvNeXtConfig {
580            variant: ConvNeXtVariant::Tiny,
581            input_channels: 3,
582            depths: vec![3, 3, 9, 3],
583            dims: vec![96, 192, 384, 768],
584            num_classes,
585            dropout_rate: Some(0.1),
586            layer_scale_init_value: 1e-6,
587            include_top,
588        })
589    }
590
591    /// Create a ConvNeXt-Small model.
592    pub fn convnext_small(num_classes: usize, include_top: bool) -> Result<Self> {
593        Self::new(ConvNeXtConfig {
594            variant: ConvNeXtVariant::Small,
595            input_channels: 3,
596            depths: vec![3, 3, 27, 3],
597            dims: vec![96, 192, 384, 768],
598            num_classes,
599            dropout_rate: Some(0.1),
600            layer_scale_init_value: 1e-6,
601            include_top,
602        })
603    }
604
605    /// Create a ConvNeXt-Base model.
606    pub fn convnext_base(num_classes: usize, include_top: bool) -> Result<Self> {
607        Self::new(ConvNeXtConfig {
608            variant: ConvNeXtVariant::Base,
609            input_channels: 3,
610            depths: vec![3, 3, 27, 3],
611            dims: vec![128, 256, 512, 1024],
612            num_classes,
613            dropout_rate: Some(0.1),
614            layer_scale_init_value: 1e-6,
615            include_top,
616        })
617    }
618
619    /// Create a ConvNeXt-Large model.
620    pub fn convnext_large(num_classes: usize, include_top: bool) -> Result<Self> {
621        Self::new(ConvNeXtConfig {
622            variant: ConvNeXtVariant::Large,
623            input_channels: 3,
624            depths: vec![3, 3, 27, 3],
625            dims: vec![192, 384, 768, 1536],
626            num_classes,
627            dropout_rate: Some(0.1),
628            layer_scale_init_value: 1e-6,
629            include_top,
630        })
631    }
632
633    /// Create a ConvNeXt-XLarge model.
634    pub fn convnext_xlarge(num_classes: usize, include_top: bool) -> Result<Self> {
635        Self::new(ConvNeXtConfig {
636            variant: ConvNeXtVariant::XLarge,
637            input_channels: 3,
638            depths: vec![3, 3, 27, 3],
639            dims: vec![256, 512, 1024, 2048],
640            num_classes,
641            dropout_rate: Some(0.1),
642            layer_scale_init_value: 1e-6,
643            include_top,
644        })
645    }
646}
647
648impl<
649        F: Float
650            + Debug
651            + ScalarOperand
652            + Send
653            + Sync
654            + NumAssign
655            + scirs2_core::simd_ops::SimdUnifiedOps
656            + 'static,
657    > Layer<F> for ConvNeXt<F>
658{
659    fn as_any(&self) -> &dyn std::any::Any {
660        self
661    }
662
663    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
664        self
665    }
666
667    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
668        let mut x = self.stem.forward(input)?;
669        for stage in &self.stages {
670            x = stage.forward(&x)?;
671        }
672        if let Some(ref head) = self.head {
673            x = head.forward(&x)?;
674        }
675        Ok(x)
676    }
677
678    fn backward(
679        &self,
680        input: &Array<F, IxDyn>,
681        grad_output: &Array<F, IxDyn>,
682    ) -> Result<Array<F, IxDyn>> {
683        let mut grad = grad_output.clone();
684        if let Some(ref head) = self.head {
685            grad = head.backward(&grad, &grad)?;
686        }
687        for stage in self.stages.iter().rev() {
688            grad = stage.backward(&grad, &grad)?;
689        }
690        self.stem.backward(input, &grad)
691    }
692
693    fn update(&mut self, learning_rate: F) -> Result<()> {
694        self.stem.update(learning_rate)?;
695        for stage in &mut self.stages {
696            stage.update(learning_rate)?;
697        }
698        if let Some(ref mut head) = self.head {
699            head.update(learning_rate)?;
700        }
701        Ok(())
702    }
703
704    fn params(&self) -> Vec<Array<F, IxDyn>> {
705        let mut params = Vec::new();
706        params.extend(self.stem.params());
707        for stage in &self.stages {
708            params.extend(stage.params());
709        }
710        if let Some(ref head) = self.head {
711            params.extend(head.params());
712        }
713        params
714    }
715
716    fn set_training(&mut self, training: bool) {
717        self.stem.set_training(training);
718        for stage in &mut self.stages {
719            stage.set_training(training);
720        }
721        if let Some(ref mut head) = self.head {
722            head.set_training(training);
723        }
724    }
725
726    fn is_training(&self) -> bool {
727        self.stem.is_training()
728    }
729
730    fn layer_type(&self) -> &str {
731        "ConvNeXt"
732    }
733}
734
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_convnext_config() {
        let ConvNeXtConfig {
            variant,
            input_channels,
            depths,
            dims,
            ..
        } = ConvNeXtConfig::default();
        assert_eq!(variant, ConvNeXtVariant::Tiny);
        assert_eq!(input_channels, 3);
        assert_eq!(depths.len(), 4);
        assert_eq!(dims.len(), 4);
    }

    #[test]
    fn test_convnext_block_creation() {
        assert!(ConvNeXtBlock::<f64>::new(64, 1e-6, 0.0).is_ok());
    }

    #[test]
    fn test_convnext_stage_config() {
        let stage_config = ConvNeXtStageConfig {
            input_channels: 64,
            output_channels: 128,
            num_blocks: 3,
            stride: 2,
            layer_scale_init_value: 1e-6,
            drop_path_prob: 0.0,
        };
        assert!(ConvNeXtStage::<f64>::new(&stage_config).is_ok());
    }

    #[test]
    fn test_convnext_downsample() {
        assert!(ConvNeXtDownsample::<f64>::new(64, 128, 2).is_ok());
    }

    #[test]
    fn test_convnext_variants() {
        let tiny = ConvNeXtVariant::Tiny;
        assert_eq!(tiny, ConvNeXtVariant::Tiny);
        assert_ne!(tiny, ConvNeXtVariant::Base);
    }
}