// Source: scirs2_neural/models/architectures/convnext.rs

//! ConvNeXt architecture implementation
//!
//! This module implements the ConvNeXt architecture as described in
//! "A ConvNet for the 2020s" (<https://arxiv.org/abs/2201.03545>).
//! ConvNeXt modernizes the ResNet architecture by incorporating design choices from
//! Vision Transformers, resulting in a pure convolutional model with excellent performance.
//!
//! Only the configuration types are currently active; the block, stage and model
//! implementations are commented out until `LayerNorm2D` is available.
7
8use crate::activations::GELU;
9use crate::error::Result;
10use crate::layers::conv::PaddingMode;
11use crate::layers::{Conv2D, Dense, Dropout, GlobalAvgPool2D, Layer, Sequential};
// Note: LayerNorm2D is not yet implemented; the block, stage and model code that
// depends on it is commented out below until it becomes available.
13use scirs2_core::ndarray::{Array, IxDyn, ScalarOperand};
14use scirs2_core::numeric::{Float, NumAssign};
15use scirs2_core::random::{rngs::SmallRng, SeedableRng};
16use serde::{Deserialize, Serialize};
17use std::fmt::Debug;
18
19/// Configuration for a ConvNeXt stage
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct ConvNeXtStageConfig {
22    /// Number of input channels
23    pub input_channels: usize,
24    /// Number of output channels
25    pub output_channels: usize,
26    /// Number of blocks in this stage
27    pub num_blocks: usize,
28    /// Stride for the first block (typically 2 for downsampling, 1 otherwise)
29    pub stride: usize,
30    /// Layer scale initialization value (typically 1e-6)
31    pub layer_scale_init_value: f64,
32    /// Dropout probability
33    pub drop_path_prob: f64,
34}
35
36/// Configuration for a ConvNeXt model
37#[derive(Debug, Clone)]
38pub struct ConvNeXtConfig {
39    /// Model depth variant (Tiny, Small, Base, Large, XLarge)
40    pub variant: ConvNeXtVariant,
41    /// Number of input channels (typically 3 for RGB images)
42    pub input_channels: usize,
43    /// Depths for each stage
44    pub depths: Vec<usize>,
45    /// Dimensions (channels) for each stage
46    pub dims: Vec<usize>,
47    /// Number of output classes
48    pub num_classes: usize,
49    /// Dropout rate
50    pub dropout_rate: Option<f64>,
51    /// Layer scale initialization value
52    pub layer_scale_init_value: f64,
53    /// Whether to include the classification head
54    pub include_top: bool,
55}
56
57impl Default for ConvNeXtConfig {
58    fn default() -> Self {
59        Self {
60            variant: ConvNeXtVariant::Tiny,
61            input_channels: 3,
62            depths: vec![3, 3, 9, 3],
63            dims: vec![96, 192, 384, 768],
64            num_classes: 1000,
65            dropout_rate: Some(0.0),
66            layer_scale_init_value: 1e-6,
67            include_top: true,
68        }
69    }
70}
71
72/// ConvNeXt model variants
73#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
74pub enum ConvNeXtVariant {
75    /// ConvNeXt-Tiny
76    Tiny,
77    /// ConvNeXt-Small
78    Small,
79    /// ConvNeXt-Base
80    Base,
81    /// ConvNeXt-Large
82    Large,
83    /// ConvNeXt-XLarge
84    XLarge,
85}
86
87// ConvNeXt block implementation
88// TODO: Re-enable once LayerNorm2D is implemented
89// #[derive(Debug, Clone)]
90// pub struct ConvNeXtBlock<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> {
91//     /// Depthwise convolution
92//     pub depthwise_conv: Conv2D<F>,
93//     /// Layer normalization
94//     pub norm: LayerNorm2D<F>,
95//     /// Pointwise convolution 1
96//     pub pointwise_conv1: Conv2D<F>,
97//     /// GELU activation
98//     pub gelu: GELU,
99//     /// Pointwise convolution 2
100//     pub pointwise_conv2: Conv2D<F>,
101//     /// Layer scale gamma parameter
102//     pub gamma: Array<F, IxDyn>,
103//     /// Skip connection flag
104//     pub use_skip: bool,
105//     /// Skip connection scale for stochastic depth
106//     pub skip_scale: F,
107// }
108
109// TODO: Re-enable once LayerNorm2D is implemented
110// impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> ConvNeXtBlock<F> {
111//     /// Create a new ConvNeXtBlock
112//     pub fn new(channels: usize, layer_scale_init_value: f64, drop_path_prob: f64) -> Result<Self> {
113//         let mut rng = scirs2_core::random::rng();
114//
115//         let depthwise_conv = Conv2D::<F>::new(
116//             channels,
117//             channels,
118//             (7, 7),
119//             (1, 1),
120//             None,
121//         )?.with_padding(PaddingMode::Custom(3));
122//
123//         let norm = LayerNorm2D::<F>::new::<SmallRng>(channels, 1e-6, Some("norm"))?;
124//
125//         let pointwise_conv1 = Conv2D::<F>::new(
126//             channels,
127//             channels * 4,
128//             (1, 1),
129//             (1, 1),
130//             None,
131//         )?.with_padding(PaddingMode::Custom(0));
132//
133//         let gelu = GELU::new();
134//
135//         let pointwise_conv2 = Conv2D::<F>::new(
136//             channels * 4,
137//             channels,
138//             (1, 1),
139//             (1, 1),
140//             None,
141//         )?.with_padding(PaddingMode::Custom(0));
142//
143//         // Initialize gamma as a learnable parameter
144//         let gamma_value = F::from(layer_scale_init_value).expect("Failed to convert to float");
145//         let gamma = Array::<F, _>::from_elem([channels, 1, 1], gamma_value).into_dyn();
146//
147//         // Stochastic depth rate
148//         let skip_scale = F::from(1.0 - drop_path_prob).expect("Failed to convert to float");
149//         let use_skip = drop_path_prob > 0.0;
150//
151//         Ok(Self {
152//             depthwise_conv,
153//             norm,
154//             pointwise_conv1,
155//             gelu,
156//             pointwise_conv2,
157//             gamma,
158//             use_skip,
159//             skip_scale,
160//         })
161//     }
162// }
163
164// TODO: Re-enable once LayerNorm2D is implemented
165// impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> Layer<F> for ConvNeXtBlock<F> {
166//     fn as_any(&self) -> &dyn std::any::Any {
167//         self
168//     }
169//
170//     fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
171//         self
172//     }
173//
174//     fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
175//         // Save input for skip connection
176//         let identity = input.clone();
177//
178//         // Depthwise convolution
179//         let mut x = self.depthwise_conv.forward(input)?;
180//
181//         // Normalization
182//         x = self.norm.forward(&x)?;
183//
184//         // First pointwise convolution and activation
185//         x = self.pointwise_conv1.forward(&x)?;
186//         x = <GELU as Layer<F>>::forward(&self.gelu, &x)?;
187//
188//         // Second pointwise convolution
189//         x = self.pointwise_conv2.forward(&x)?;
190//
191//         // Apply layer scale
192//         let shape = x.shape().to_vec();
193//         if shape.len() >= 4 {
194//             let view = x
195//                 .clone()
196//                 .into_shape_with_order((shape[0], shape[1], shape[2] * shape[3]))?;
197//             let scaled = view * &self.gamma;
198//             x = scaled.into_shape_with_order(shape).expect("Operation failed");
199//         }
200//
201//         // Apply stochastic depth and skip connection
202//         if self.use_skip {
203//             // During training, scale the output by (1 - drop_path_prob)
204//             x = x * self.skip_scale;
205//         }
206//
207//         // Add skip connection
208//         x = x + identity;
209//
210//         Ok(x)
211//     }
212//
213//     fn backward(
214//         &self,
215//         input: &Array<F, IxDyn>,
216//         grad_output: &Array<F, IxDyn>,
217//     ) -> Result<Array<F, IxDyn>> {
218//         // ConvNeXt backward pass in reverse order of forward pass
219//         let mut grad = grad_output.clone();
220//
221//         // Gradient through skip connection
222//         let grad_skip = grad.clone();
223//
224//         // Gradient through stochastic depth scaling
225//         if self.use_skip {
226//             grad = grad * self.skip_scale;
227//         }
228//
229//         // Gradient through layer scale
230//         let shape = grad.shape().to_vec();
231//         if shape.len() >= 4 {
232//             let grad_view = grad
233//                 .clone()
234//                 .into_shape_with_order((shape[0], shape[1], shape[2] * shape[3]))?;
235//             let grad_scaled = grad_view * &self.gamma;
236//             grad = grad_scaled.into_shape_with_order(shape).expect("Operation failed");
237//         }
238//
239//         // Backward through second pointwise convolution
240//         let grad_after_conv2 = self.pointwise_conv2.backward(&grad, &grad)?;
241//
242//         // Backward through GELU activation (simplified)
243//         let grad_after_gelu = grad_after_conv2.clone();
244//
245//         // Backward through first pointwise convolution
246//         let grad_after_conv1 = self
247//             .pointwise_conv1
248//             .backward(&grad_after_gelu, &grad_after_gelu)?;
249//
250//         // Backward through normalization
251//         let grad_after_norm = self.norm.backward(&grad_after_conv1, &grad_after_conv1)?;
252//
253//         // Backward through depthwise convolution
254//         let grad_after_dwconv = self.depthwise_conv.backward(input, &grad_after_norm)?;
255//
256//         // Combine gradient from main path and skip connection
257//         let grad_input = grad_after_dwconv + grad_skip;
258//
259//         Ok(grad_input)
260//     }
261//
262//     fn update(&mut self, learning_rate: F) -> Result<()> {
263//         self.depthwise_conv.update(learning_rate)?;
264//         self.norm.update(learning_rate)?;
265//         self.pointwise_conv1.update(learning_rate)?;
266//         self.pointwise_conv2.update(learning_rate)?;
267//
268//         // Update gamma parameter
269//         let small_update = F::from(0.0001).expect("Failed to convert constant to float") * learning_rate;
270//         for elem in self.gamma.iter_mut() {
271//             *elem = *elem - small_update;
272//         }
273//
274//         Ok(())
275//     }
276//
277//     fn params(&self) -> Vec<Array<F, IxDyn>> {
278//         let mut params = Vec::new();
279//         params.extend(self.depthwise_conv.params());
280//         params.extend(self.norm.params());
281//         params.extend(self.pointwise_conv1.params());
282//         params.extend(self.pointwise_conv2.params());
283//         params.push(self.gamma.clone());
284//         params
285//     }
286//
287//     fn set_training(&mut self, training: bool) {
288//         self.depthwise_conv.set_training(training);
289//         self.norm.set_training(training);
290//         self.pointwise_conv1.set_training(training);
291//         self.pointwise_conv2.set_training(training);
292//         <GELU as Layer<F>>::set_training(&mut self.gelu, training);
293//     }
294//
295//     fn is_training(&self) -> bool {
296//         self.depthwise_conv.is_training()
297//     }
298// }
299
300// ConvNeXt downsampling layer
301// TODO: Re-enable once LayerNorm2D is implemented
302// #[derive(Debug, Clone)]
303// pub struct ConvNeXtDownsample<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> {
304//     /// Layer normalization before convolution
305//     pub norm: LayerNorm2D<F>,
306//     /// Convolution for downsampling
307//     pub conv: Conv2D<F>,
308// }
309//
310// impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> ConvNeXtDownsample<F> {
311//     /// Create a new ConvNeXtDownsample
312//     pub fn new(in_channels: usize, out_channels: usize, stride: usize) -> Result<Self> {
313//         let norm = LayerNorm2D::<F>::new::<SmallRng>(in_channels, 1e-6, Some("downsample_norm"))?;
314//
315//         let mut rng = scirs2_core::random::rng();
316//         let conv = Conv2D::<F>::new(
317//             in_channels,
318//             out_channels,
319//             (stride, stride),
320//             (stride, stride),
321//             None,
322//         )?.with_padding(PaddingMode::Custom(0));
323//
324//         Ok(Self { norm, conv })
325//     }
326// }
327//
328// impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> Layer<F> for ConvNeXtDownsample<F> {
329//     fn as_any(&self) -> &dyn std::any::Any {
330//         self
331//     }
332//
333//     fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
334//         self
335//     }
336//
337//     fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
338//         let x = self.norm.forward(input)?;
339//         self.conv.forward(&x)
340//     }
341//
342//     fn backward(
343//         &self,
344//         input: &Array<F, IxDyn>,
345//         grad_output: &Array<F, IxDyn>,
346//     ) -> Result<Array<F, IxDyn>> {
347//         let grad_after_conv = self.conv.backward(grad_output, grad_output)?;
348//         self.norm.backward(input, &grad_after_conv)
349//     }
350//
351//     fn update(&mut self, learning_rate: F) -> Result<()> {
352//         self.norm.update(learning_rate)?;
353//         self.conv.update(learning_rate)?;
354//         Ok(())
355//     }
356//
357//     fn params(&self) -> Vec<Array<F, IxDyn>> {
358//         let mut params = Vec::new();
359//         params.extend(self.norm.params());
360//         params.extend(self.conv.params());
361//         params
362//     }
363//
364//     fn set_training(&mut self, training: bool) {
365//         self.norm.set_training(training);
366//         self.conv.set_training(training);
367//     }
368//
369//     fn is_training(&self) -> bool {
370//         self.norm.is_training()
371//     }
372// }
373
374// ConvNeXt stage
375// TODO: Re-enable once LayerNorm2D is implemented
376// #[derive(Debug, Clone)]
377// pub struct ConvNeXtStage<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> {
378//     /// Downsampling layer (optional)
379//     pub downsample: Option<ConvNeXtDownsample<F>>,
380//     /// Blocks in this stage
381//     pub blocks: Vec<ConvNeXtBlock<F>>,
382// }
383//
384// impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> ConvNeXtStage<F> {
385//     /// Create a new ConvNeXtStage
386//     pub fn new(config: &ConvNeXtStageConfig) -> Result<Self> {
387//         // Create the downsampling layer if needed
388//         let downsample = if config.input_channels != config.output_channels || config.stride > 1 {
389//             Some(ConvNeXtDownsample::<F>::new(
390//                 config.input_channels,
391//                 config.output_channels,
392//                 config.stride,
393//             )?)
394//         } else {
395//             None
396//         };
397//
398//         // Create the blocks
399//         let mut blocks = Vec::with_capacity(config.num_blocks);
400//         for _ in 0..config.num_blocks {
401//             blocks.push(ConvNeXtBlock::<F>::new(
402//                 config.output_channels,
403//                 config.layer_scale_init_value,
404//                 config.drop_path_prob,
405//             )?);
406//         }
407//
408//         Ok(Self { downsample, blocks })
409//     }
410// }
411//
412// impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> Layer<F> for ConvNeXtStage<F> {
413//     fn as_any(&self) -> &dyn std::any::Any {
414//         self
415//     }
416//
417//     fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
418//         self
419//     }
420//
421//     fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
422//         // Apply downsampling if available
423//         let mut x = if let Some(ref downsample) = self.downsample {
424//             downsample.forward(input)?
425//         } else {
426//             input.clone()
427//         };
428//
429//         // Apply all blocks
430//         for block in &self.blocks {
431//             x = block.forward(&x)?;
432//         }
433//
434//         Ok(x)
435//     }
436//
437//     fn backward(
438//         &self,
439//         input: &Array<F, IxDyn>,
440//         grad_output: &Array<F, IxDyn>,
441//     ) -> Result<Array<F, IxDyn>> {
442//         let mut grad = grad_output.clone();
443//
444//         // Backward through blocks in reverse order
445//         for block in self.blocks.iter().rev() {
446//             grad = block.backward(&grad, &grad)?;
447//         }
448//
449//         // Backward through downsampling if it exists
450//         if let Some(ref downsample) = self.downsample {
451//             grad = downsample.backward(input, &grad)?;
452//         }
453//
454//         Ok(grad)
455//     }
456//
457//     fn update(&mut self, learning_rate: F) -> Result<()> {
458//         if let Some(ref mut downsample) = self.downsample {
459//             downsample.update(learning_rate)?;
460//         }
461//
462//         for block in &mut self.blocks {
463//             block.update(learning_rate)?;
464//         }
465//
466//         Ok(())
467//     }
468//
469//     fn params(&self) -> Vec<Array<F, IxDyn>> {
470//         let mut params = Vec::new();
471//         if let Some(ref downsample) = self.downsample {
472//             params.extend(downsample.params());
473//         }
474//         for block in &self.blocks {
475//             params.extend(block.params());
476//         }
477//         params
478//     }
479//
480//     fn set_training(&mut self, training: bool) {
481//         if let Some(ref mut downsample) = self.downsample {
482//             downsample.set_training(training);
483//         }
484//         for block in &mut self.blocks {
485//             block.set_training(training);
486//         }
487//     }
488//
489//     fn is_training(&self) -> bool {
490//         if let Some(ref downsample) = self.downsample {
491//             return downsample.is_training();
492//         }
493//         if !self.blocks.is_empty() {
494//             return self.blocks[0].is_training();
495//         }
496//         true
497//     }
498// }
499
500// ConvNeXt model
501// TODO: Re-enable once LayerNorm2D is implemented
502// #[derive(Debug)]
503// pub struct ConvNeXt<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> {
504//     /// Stem layer (initial convolution)
505//     pub stem: Sequential<F>,
506//     /// Main stages of the network
507//     pub stages: Vec<ConvNeXtStage<F>>,
508//     /// Classification head (if include_top is true)
509//     pub head: Option<Sequential<F>>,
510//     /// Model configuration
511//     pub config: ConvNeXtConfig,
512// }
513
514// TODO: Re-enable once LayerNorm2D is implemented
515// impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> ConvNeXt<F> {
516//     /// Create a new ConvNeXt model
517//     pub fn new(config: ConvNeXtConfig) -> Result<Self> {
518//         let mut rng = scirs2_core::random::rng();
519//
520//         // Create the stem layer
521//         let mut stem = Sequential::new();
522//         stem.add(Conv2D::<F>::new(
523//             config.input_channels,
524//             config.dims[0],
525//             (4, 4),
526//             (4, 4),
527//             None,
528//         )?.with_padding(PaddingMode::Custom(0)));
529//         stem.add(LayerNorm2D::<F>::new::<SmallRng>(
530//             config.dims[0],
531//             1e-6,
532//             Some("stem_norm"),
533//         )?);
534//
535//         // Create the stages
536//         let mut stages = Vec::with_capacity(config.depths.len());
537//         let mut current_channels = config.dims[0];
538//
539//         for (i, &depth) in config.depths.iter().enumerate() {
540//             let output_channels = config.dims[i];
541//             let stride = if i == 0 { 1 } else { 2 };
542//
543//             let stage_config = ConvNeXtStageConfig {
544//                 input_channels: current_channels,
545//                 output_channels,
546//                 num_blocks: depth,
547//                 stride,
548//                 layer_scale_init_value: config.layer_scale_init_value,
549//                 drop_path_prob: 0.0,
550//             };
551//
552//             stages.push(ConvNeXtStage::<F>::new(&stage_config)?);
553//             current_channels = output_channels;
554//         }
555//
556//         // Create the head if needed
557//         let head = if config.include_top {
558//             let mut head_seq = Sequential::new();
559//
560//             head_seq.add(LayerNorm2D::<F>::new::<SmallRng>(
561//                 *config.dims.last().expect("Operation failed"),
562//                 1e-6,
563//                 Some("head_norm"),
564//             )?);
565//
566//             head_seq.add(GlobalAvgPool2D::<F>::new(Some("head_pool"))?);
567//
568//             if let Some(dropout_rate) = config.dropout_rate {
569//                 if dropout_rate > 0.0 {
570//                     head_seq.add(Dropout::<F>::new(dropout_rate, &mut rng)?);
571//                 }
572//             }
573//
574//             head_seq.add(Dense::<F>::new(
575//                 *config.dims.last().expect("Operation failed"),
576//                 config.num_classes,
577//                 Some("classifier"),
578//                 &mut rng,
579//             )?);
580//
581//             Some(head_seq)
582//         } else {
583//             None
584//         };
585//
586//         Ok(Self {
587//             stem,
588//             stages,
589//             head,
590//             config,
591//         })
592//     }
593//
594//     /// Create a ConvNeXt-Tiny model
595//     pub fn convnext_tiny(num_classes: usize, include_top: bool) -> Result<Self> {
596//         let config = ConvNeXtConfig {
597//             variant: ConvNeXtVariant::Tiny,
598//             input_channels: 3,
599//             depths: vec![3, 3, 9, 3],
600//             dims: vec![96, 192, 384, 768],
601//             num_classes,
602//             dropout_rate: Some(0.1),
603//             layer_scale_init_value: 1e-6,
604//             include_top,
605//         };
606//         Self::new(config)
607//     }
608//
609//     /// Create a ConvNeXt-Small model
610//     pub fn convnext_small(num_classes: usize, include_top: bool) -> Result<Self> {
611//         let config = ConvNeXtConfig {
612//             variant: ConvNeXtVariant::Small,
613//             input_channels: 3,
614//             depths: vec![3, 3, 27, 3],
615//             dims: vec![96, 192, 384, 768],
616//             num_classes,
617//             dropout_rate: Some(0.1),
618//             layer_scale_init_value: 1e-6,
619//             include_top,
620//         };
621//         Self::new(config)
622//     }
623//
624//     /// Create a ConvNeXt-Base model
625//     pub fn convnext_base(num_classes: usize, include_top: bool) -> Result<Self> {
626//         let config = ConvNeXtConfig {
627//             variant: ConvNeXtVariant::Base,
628//             input_channels: 3,
629//             depths: vec![3, 3, 27, 3],
630//             dims: vec![128, 256, 512, 1024],
631//             num_classes,
632//             dropout_rate: Some(0.1),
633//             layer_scale_init_value: 1e-6,
634//             include_top,
635//         };
636//         Self::new(config)
637//     }
638//
639//     /// Create a ConvNeXt-Large model
640//     pub fn convnext_large(num_classes: usize, include_top: bool) -> Result<Self> {
641//         let config = ConvNeXtConfig {
642//             variant: ConvNeXtVariant::Large,
643//             input_channels: 3,
644//             depths: vec![3, 3, 27, 3],
645//             dims: vec![192, 384, 768, 1536],
646//             num_classes,
647//             dropout_rate: Some(0.1),
648//             layer_scale_init_value: 1e-6,
649//             include_top,
650//         };
651//         Self::new(config)
652//     }
653//
654//     /// Create a ConvNeXt-XLarge model
655//     pub fn convnext_xlarge(num_classes: usize, include_top: bool) -> Result<Self> {
656//         let config = ConvNeXtConfig {
657//             variant: ConvNeXtVariant::XLarge,
658//             input_channels: 3,
659//             depths: vec![3, 3, 27, 3],
660//             dims: vec![256, 512, 1024, 2048],
661//             num_classes,
662//             dropout_rate: Some(0.1),
663//             layer_scale_init_value: 1e-6,
664//             include_top,
665//         };
666//         Self::new(config)
667//     }
668// }
669//
670// impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> Layer<F> for ConvNeXt<F> {
671//     fn as_any(&self) -> &dyn std::any::Any {
672//         self
673//     }
674//
675//     fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
676//         self
677//     }
678//
679//     fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
680//         // Apply stem
681//         let mut x = self.stem.forward(input)?;
682//
683//         // Apply stages
684//         for stage in &self.stages {
685//             x = stage.forward(&x)?;
686//         }
687//
688//         // Apply head if available
689//         if let Some(ref head) = self.head {
690//             x = head.forward(&x)?;
691//         }
692//
693//         Ok(x)
694//     }
695//
696//     fn backward(
697//         &self,
698//         input: &Array<F, IxDyn>,
699//         grad_output: &Array<F, IxDyn>,
700//     ) -> Result<Array<F, IxDyn>> {
701//         let mut grad = grad_output.clone();
702//
703//         // Backward through head if it exists
704//         if let Some(ref head) = self.head {
705//             grad = head.backward(&grad, &grad)?;
706//         }
707//
708//         // Backward through stages in reverse order
709//         for stage in self.stages.iter().rev() {
710//             grad = stage.backward(&grad, &grad)?;
711//         }
712//
713//         // Backward through stem
714//         self.stem.backward(input, &grad)
715//     }
716//
717//     fn update(&mut self, learning_rate: F) -> Result<()> {
718//         self.stem.update(learning_rate)?;
719//
720//         for stage in &mut self.stages {
721//             stage.update(learning_rate)?;
722//         }
723//
724//         if let Some(ref mut head) = self.head {
725//             head.update(learning_rate)?;
726//         }
727//
728//         Ok(())
729//     }
730//
731//     fn params(&self) -> Vec<Array<F, IxDyn>> {
732//         let mut params = Vec::new();
733//         params.extend(self.stem.params());
734//         for stage in &self.stages {
735//             params.extend(stage.params());
736//         }
737//         if let Some(ref head) = self.head {
738//             params.extend(head.params());
739//         }
740//         params
741//     }
742//
743//     fn set_training(&mut self, training: bool) {
744//         self.stem.set_training(training);
745//         for stage in &mut self.stages {
746//             stage.set_training(training);
747//         }
748//         if let Some(ref mut head) = self.head {
749//             head.set_training(training);
750//         }
751//     }
752//
753//     fn is_training(&self) -> bool {
754//         self.stem.is_training()
755//     }
756// }
757
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration must be the ConvNeXt-Tiny preset, with one
    /// depth entry and one width entry per stage.
    #[test]
    fn test_convnext_config() {
        let config = ConvNeXtConfig::default();
        assert_eq!(config.variant, ConvNeXtVariant::Tiny);
        assert_eq!(config.input_channels, 3);
        assert_eq!(config.depths.len(), 4);
        assert_eq!(config.dims.len(), 4);

        // Pin the actual preset values, not just the list lengths.
        assert_eq!(config.depths, vec![3, 3, 9, 3]);
        assert_eq!(config.dims, vec![96, 192, 384, 768]);
        assert_eq!(config.num_classes, 1000);
        assert_eq!(config.dropout_rate, Some(0.0));
        assert!((config.layer_scale_init_value - 1e-6).abs() < f64::EPSILON);
        assert!(config.include_top);
    }

    // TODO: Re-enable once LayerNorm2D is implemented
    // #[test]
    // fn test_convnext_block_creation() {
    //     let block = ConvNeXtBlock::<f64>::new(64, 1e-6, 0.0);
    //     assert!(block.is_ok());
    // }

    // TODO: Re-enable once LayerNorm2D is implemented
    // #[test]
    // fn test_convnext_stage_config() {
    //     let config = ConvNeXtStageConfig {
    //         input_channels: 64,
    //         output_channels: 128,
    //         num_blocks: 3,
    //         stride: 2,
    //         layer_scale_init_value: 1e-6,
    //         drop_path_prob: 0.0,
    //     };
    //
    //     let stage = ConvNeXtStage::<f64>::new(&config);
    //     assert!(stage.is_ok());
    // }

    // TODO: Re-enable once LayerNorm2D is implemented
    // #[test]
    // fn test_convnext_downsample() {
    //     let downsample = ConvNeXtDownsample::<f64>::new(64, 128, 2);
    //     assert!(downsample.is_ok());
    // }

    /// Every variant equals itself and differs from every other variant.
    #[test]
    fn test_convnext_variants() {
        assert_eq!(ConvNeXtVariant::Tiny, ConvNeXtVariant::Tiny);
        assert_ne!(ConvNeXtVariant::Tiny, ConvNeXtVariant::Base);

        let all = [
            ConvNeXtVariant::Tiny,
            ConvNeXtVariant::Small,
            ConvNeXtVariant::Base,
            ConvNeXtVariant::Large,
            ConvNeXtVariant::XLarge,
        ];
        for (i, lhs) in all.iter().enumerate() {
            for (j, rhs) in all.iter().enumerate() {
                assert_eq!(lhs == rhs, i == j);
            }
        }
    }
}