scirs2_neural/models/architectures/convnext.rs
1//! ConvNeXt architecture implementation
2//!
//! This module implements the ConvNeXt architecture as described in
//! "A ConvNet for the 2020s" (<https://arxiv.org/abs/2201.03545>).
//! ConvNeXt modernizes the ResNet architecture by incorporating design choices from
//! Vision Transformers, resulting in a pure convolutional model with excellent performance.
7
8use crate::activations::GELU;
9use crate::error::Result;
10use crate::layers::conv::PaddingMode;
11use crate::layers::{Conv2D, Dense, Dropout, GlobalAvgPool2D, Layer, Sequential};
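// Note: LayerNorm2D is not yet implemented, so the block, downsample, stage, and model
// definitions below that depend on it are commented out until it is available.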
13use scirs2_core::ndarray::{Array, IxDyn, ScalarOperand};
14use scirs2_core::numeric::{Float, NumAssign};
15use scirs2_core::random::{rngs::SmallRng, SeedableRng};
16use serde::{Deserialize, Serialize};
17use std::fmt::Debug;
18
19/// Configuration for a ConvNeXt stage
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct ConvNeXtStageConfig {
22 /// Number of input channels
23 pub input_channels: usize,
24 /// Number of output channels
25 pub output_channels: usize,
26 /// Number of blocks in this stage
27 pub num_blocks: usize,
28 /// Stride for the first block (typically 2 for downsampling, 1 otherwise)
29 pub stride: usize,
30 /// Layer scale initialization value (typically 1e-6)
31 pub layer_scale_init_value: f64,
32 /// Dropout probability
33 pub drop_path_prob: f64,
34}
35
36/// Configuration for a ConvNeXt model
37#[derive(Debug, Clone)]
38pub struct ConvNeXtConfig {
39 /// Model depth variant (Tiny, Small, Base, Large, XLarge)
40 pub variant: ConvNeXtVariant,
41 /// Number of input channels (typically 3 for RGB images)
42 pub input_channels: usize,
43 /// Depths for each stage
44 pub depths: Vec<usize>,
45 /// Dimensions (channels) for each stage
46 pub dims: Vec<usize>,
47 /// Number of output classes
48 pub num_classes: usize,
49 /// Dropout rate
50 pub dropout_rate: Option<f64>,
51 /// Layer scale initialization value
52 pub layer_scale_init_value: f64,
53 /// Whether to include the classification head
54 pub include_top: bool,
55}
56
57impl Default for ConvNeXtConfig {
58 fn default() -> Self {
59 Self {
60 variant: ConvNeXtVariant::Tiny,
61 input_channels: 3,
62 depths: vec![3, 3, 9, 3],
63 dims: vec![96, 192, 384, 768],
64 num_classes: 1000,
65 dropout_rate: Some(0.0),
66 layer_scale_init_value: 1e-6,
67 include_top: true,
68 }
69 }
70}
71
72/// ConvNeXt model variants
73#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
74pub enum ConvNeXtVariant {
75 /// ConvNeXt-Tiny
76 Tiny,
77 /// ConvNeXt-Small
78 Small,
79 /// ConvNeXt-Base
80 Base,
81 /// ConvNeXt-Large
82 Large,
83 /// ConvNeXt-XLarge
84 XLarge,
85}
86
87// ConvNeXt block implementation
88// TODO: Re-enable once LayerNorm2D is implemented
89// #[derive(Debug, Clone)]
90// pub struct ConvNeXtBlock<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> {
91// /// Depthwise convolution
92// pub depthwise_conv: Conv2D<F>,
93// /// Layer normalization
94// pub norm: LayerNorm2D<F>,
95// /// Pointwise convolution 1
96// pub pointwise_conv1: Conv2D<F>,
97// /// GELU activation
98// pub gelu: GELU,
99// /// Pointwise convolution 2
100// pub pointwise_conv2: Conv2D<F>,
101// /// Layer scale gamma parameter
102// pub gamma: Array<F, IxDyn>,
103// /// Skip connection flag
104// pub use_skip: bool,
105// /// Skip connection scale for stochastic depth
106// pub skip_scale: F,
107// }
108
109// TODO: Re-enable once LayerNorm2D is implemented
110// impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> ConvNeXtBlock<F> {
111// /// Create a new ConvNeXtBlock
112// pub fn new(channels: usize, layer_scale_init_value: f64, drop_path_prob: f64) -> Result<Self> {
113// let mut rng = scirs2_core::random::rng();
114//
115// let depthwise_conv = Conv2D::<F>::new(
116// channels,
117// channels,
118// (7, 7),
119// (1, 1),
120// None,
121// )?.with_padding(PaddingMode::Custom(3));
122//
123// let norm = LayerNorm2D::<F>::new::<SmallRng>(channels, 1e-6, Some("norm"))?;
124//
125// let pointwise_conv1 = Conv2D::<F>::new(
126// channels,
127// channels * 4,
128// (1, 1),
129// (1, 1),
130// None,
131// )?.with_padding(PaddingMode::Custom(0));
132//
133// let gelu = GELU::new();
134//
135// let pointwise_conv2 = Conv2D::<F>::new(
136// channels * 4,
137// channels,
138// (1, 1),
139// (1, 1),
140// None,
141// )?.with_padding(PaddingMode::Custom(0));
142//
143// // Initialize gamma as a learnable parameter
144// let gamma_value = F::from(layer_scale_init_value).expect("Failed to convert to float");
145// let gamma = Array::<F, _>::from_elem([channels, 1, 1], gamma_value).into_dyn();
146//
147// // Stochastic depth rate
148// let skip_scale = F::from(1.0 - drop_path_prob).expect("Failed to convert to float");
149// let use_skip = drop_path_prob > 0.0;
150//
151// Ok(Self {
152// depthwise_conv,
153// norm,
154// pointwise_conv1,
155// gelu,
156// pointwise_conv2,
157// gamma,
158// use_skip,
159// skip_scale,
160// })
161// }
162// }
163
164// TODO: Re-enable once LayerNorm2D is implemented
165// impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> Layer<F> for ConvNeXtBlock<F> {
166// fn as_any(&self) -> &dyn std::any::Any {
167// self
168// }
169//
170// fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
171// self
172// }
173//
174// fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
175// // Save input for skip connection
176// let identity = input.clone();
177//
178// // Depthwise convolution
179// let mut x = self.depthwise_conv.forward(input)?;
180//
181// // Normalization
182// x = self.norm.forward(&x)?;
183//
184// // First pointwise convolution and activation
185// x = self.pointwise_conv1.forward(&x)?;
186// x = <GELU as Layer<F>>::forward(&self.gelu, &x)?;
187//
188// // Second pointwise convolution
189// x = self.pointwise_conv2.forward(&x)?;
190//
191// // Apply layer scale
192// let shape = x.shape().to_vec();
193// if shape.len() >= 4 {
194// let view = x
195// .clone()
196// .into_shape_with_order((shape[0], shape[1], shape[2] * shape[3]))?;
197// let scaled = view * &self.gamma;
198// x = scaled.into_shape_with_order(shape).expect("Operation failed");
199// }
200//
201// // Apply stochastic depth and skip connection
202// if self.use_skip {
203// // During training, scale the output by (1 - drop_path_prob)
204// x = x * self.skip_scale;
205// }
206//
207// // Add skip connection
208// x = x + identity;
209//
210// Ok(x)
211// }
212//
213// fn backward(
214// &self,
215// input: &Array<F, IxDyn>,
216// grad_output: &Array<F, IxDyn>,
217// ) -> Result<Array<F, IxDyn>> {
218// // ConvNeXt backward pass in reverse order of forward pass
219// let mut grad = grad_output.clone();
220//
221// // Gradient through skip connection
222// let grad_skip = grad.clone();
223//
224// // Gradient through stochastic depth scaling
225// if self.use_skip {
226// grad = grad * self.skip_scale;
227// }
228//
229// // Gradient through layer scale
230// let shape = grad.shape().to_vec();
231// if shape.len() >= 4 {
232// let grad_view = grad
233// .clone()
234// .into_shape_with_order((shape[0], shape[1], shape[2] * shape[3]))?;
235// let grad_scaled = grad_view * &self.gamma;
236// grad = grad_scaled.into_shape_with_order(shape).expect("Operation failed");
237// }
238//
239// // Backward through second pointwise convolution
240// let grad_after_conv2 = self.pointwise_conv2.backward(&grad, &grad)?;
241//
242// // Backward through GELU activation (simplified)
243// let grad_after_gelu = grad_after_conv2.clone();
244//
245// // Backward through first pointwise convolution
246// let grad_after_conv1 = self
247// .pointwise_conv1
248// .backward(&grad_after_gelu, &grad_after_gelu)?;
249//
250// // Backward through normalization
251// let grad_after_norm = self.norm.backward(&grad_after_conv1, &grad_after_conv1)?;
252//
253// // Backward through depthwise convolution
254// let grad_after_dwconv = self.depthwise_conv.backward(input, &grad_after_norm)?;
255//
256// // Combine gradient from main path and skip connection
257// let grad_input = grad_after_dwconv + grad_skip;
258//
259// Ok(grad_input)
260// }
261//
262// fn update(&mut self, learning_rate: F) -> Result<()> {
263// self.depthwise_conv.update(learning_rate)?;
264// self.norm.update(learning_rate)?;
265// self.pointwise_conv1.update(learning_rate)?;
266// self.pointwise_conv2.update(learning_rate)?;
267//
268// // Update gamma parameter
269// let small_update = F::from(0.0001).expect("Failed to convert constant to float") * learning_rate;
270// for elem in self.gamma.iter_mut() {
271// *elem = *elem - small_update;
272// }
273//
274// Ok(())
275// }
276//
277// fn params(&self) -> Vec<Array<F, IxDyn>> {
278// let mut params = Vec::new();
279// params.extend(self.depthwise_conv.params());
280// params.extend(self.norm.params());
281// params.extend(self.pointwise_conv1.params());
282// params.extend(self.pointwise_conv2.params());
283// params.push(self.gamma.clone());
284// params
285// }
286//
287// fn set_training(&mut self, training: bool) {
288// self.depthwise_conv.set_training(training);
289// self.norm.set_training(training);
290// self.pointwise_conv1.set_training(training);
291// self.pointwise_conv2.set_training(training);
292// <GELU as Layer<F>>::set_training(&mut self.gelu, training);
293// }
294//
295// fn is_training(&self) -> bool {
296// self.depthwise_conv.is_training()
297// }
298// }
299
300// ConvNeXt downsampling layer
301// TODO: Re-enable once LayerNorm2D is implemented
302// #[derive(Debug, Clone)]
303// pub struct ConvNeXtDownsample<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> {
304// /// Layer normalization before convolution
305// pub norm: LayerNorm2D<F>,
306// /// Convolution for downsampling
307// pub conv: Conv2D<F>,
308// }
309//
310// impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> ConvNeXtDownsample<F> {
311// /// Create a new ConvNeXtDownsample
312// pub fn new(in_channels: usize, out_channels: usize, stride: usize) -> Result<Self> {
313// let norm = LayerNorm2D::<F>::new::<SmallRng>(in_channels, 1e-6, Some("downsample_norm"))?;
314//
315// let mut rng = scirs2_core::random::rng();
316// let conv = Conv2D::<F>::new(
317// in_channels,
318// out_channels,
319// (stride, stride),
320// (stride, stride),
321// None,
322// )?.with_padding(PaddingMode::Custom(0));
323//
324// Ok(Self { norm, conv })
325// }
326// }
327//
328// impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> Layer<F> for ConvNeXtDownsample<F> {
329// fn as_any(&self) -> &dyn std::any::Any {
330// self
331// }
332//
333// fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
334// self
335// }
336//
337// fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
338// let x = self.norm.forward(input)?;
339// self.conv.forward(&x)
340// }
341//
342// fn backward(
343// &self,
344// input: &Array<F, IxDyn>,
345// grad_output: &Array<F, IxDyn>,
346// ) -> Result<Array<F, IxDyn>> {
347// let grad_after_conv = self.conv.backward(grad_output, grad_output)?;
348// self.norm.backward(input, &grad_after_conv)
349// }
350//
351// fn update(&mut self, learning_rate: F) -> Result<()> {
352// self.norm.update(learning_rate)?;
353// self.conv.update(learning_rate)?;
354// Ok(())
355// }
356//
357// fn params(&self) -> Vec<Array<F, IxDyn>> {
358// let mut params = Vec::new();
359// params.extend(self.norm.params());
360// params.extend(self.conv.params());
361// params
362// }
363//
364// fn set_training(&mut self, training: bool) {
365// self.norm.set_training(training);
366// self.conv.set_training(training);
367// }
368//
369// fn is_training(&self) -> bool {
370// self.norm.is_training()
371// }
372// }
373
374// ConvNeXt stage
375// TODO: Re-enable once LayerNorm2D is implemented
376// #[derive(Debug, Clone)]
377// pub struct ConvNeXtStage<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> {
378// /// Downsampling layer (optional)
379// pub downsample: Option<ConvNeXtDownsample<F>>,
380// /// Blocks in this stage
381// pub blocks: Vec<ConvNeXtBlock<F>>,
382// }
383//
384// impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> ConvNeXtStage<F> {
385// /// Create a new ConvNeXtStage
386// pub fn new(config: &ConvNeXtStageConfig) -> Result<Self> {
387// // Create the downsampling layer if needed
388// let downsample = if config.input_channels != config.output_channels || config.stride > 1 {
389// Some(ConvNeXtDownsample::<F>::new(
390// config.input_channels,
391// config.output_channels,
392// config.stride,
393// )?)
394// } else {
395// None
396// };
397//
398// // Create the blocks
399// let mut blocks = Vec::with_capacity(config.num_blocks);
400// for _ in 0..config.num_blocks {
401// blocks.push(ConvNeXtBlock::<F>::new(
402// config.output_channels,
403// config.layer_scale_init_value,
404// config.drop_path_prob,
405// )?);
406// }
407//
408// Ok(Self { downsample, blocks })
409// }
410// }
411//
412// impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> Layer<F> for ConvNeXtStage<F> {
413// fn as_any(&self) -> &dyn std::any::Any {
414// self
415// }
416//
417// fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
418// self
419// }
420//
421// fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
422// // Apply downsampling if available
423// let mut x = if let Some(ref downsample) = self.downsample {
424// downsample.forward(input)?
425// } else {
426// input.clone()
427// };
428//
429// // Apply all blocks
430// for block in &self.blocks {
431// x = block.forward(&x)?;
432// }
433//
434// Ok(x)
435// }
436//
437// fn backward(
438// &self,
439// input: &Array<F, IxDyn>,
440// grad_output: &Array<F, IxDyn>,
441// ) -> Result<Array<F, IxDyn>> {
442// let mut grad = grad_output.clone();
443//
444// // Backward through blocks in reverse order
445// for block in self.blocks.iter().rev() {
446// grad = block.backward(&grad, &grad)?;
447// }
448//
449// // Backward through downsampling if it exists
450// if let Some(ref downsample) = self.downsample {
451// grad = downsample.backward(input, &grad)?;
452// }
453//
454// Ok(grad)
455// }
456//
457// fn update(&mut self, learning_rate: F) -> Result<()> {
458// if let Some(ref mut downsample) = self.downsample {
459// downsample.update(learning_rate)?;
460// }
461//
462// for block in &mut self.blocks {
463// block.update(learning_rate)?;
464// }
465//
466// Ok(())
467// }
468//
469// fn params(&self) -> Vec<Array<F, IxDyn>> {
470// let mut params = Vec::new();
471// if let Some(ref downsample) = self.downsample {
472// params.extend(downsample.params());
473// }
474// for block in &self.blocks {
475// params.extend(block.params());
476// }
477// params
478// }
479//
480// fn set_training(&mut self, training: bool) {
481// if let Some(ref mut downsample) = self.downsample {
482// downsample.set_training(training);
483// }
484// for block in &mut self.blocks {
485// block.set_training(training);
486// }
487// }
488//
489// fn is_training(&self) -> bool {
490// if let Some(ref downsample) = self.downsample {
491// return downsample.is_training();
492// }
493// if !self.blocks.is_empty() {
494// return self.blocks[0].is_training();
495// }
496// true
497// }
498// }
499
500// ConvNeXt model
501// TODO: Re-enable once LayerNorm2D is implemented
502// #[derive(Debug)]
503// pub struct ConvNeXt<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> {
504// /// Stem layer (initial convolution)
505// pub stem: Sequential<F>,
506// /// Main stages of the network
507// pub stages: Vec<ConvNeXtStage<F>>,
508// /// Classification head (if include_top is true)
509// pub head: Option<Sequential<F>>,
510// /// Model configuration
511// pub config: ConvNeXtConfig,
512// }
513
514// TODO: Re-enable once LayerNorm2D is implemented
515// impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> ConvNeXt<F> {
516// /// Create a new ConvNeXt model
517// pub fn new(config: ConvNeXtConfig) -> Result<Self> {
518// let mut rng = scirs2_core::random::rng();
519//
520// // Create the stem layer
521// let mut stem = Sequential::new();
522// stem.add(Conv2D::<F>::new(
523// config.input_channels,
524// config.dims[0],
525// (4, 4),
526// (4, 4),
527// None,
528// )?.with_padding(PaddingMode::Custom(0)));
529// stem.add(LayerNorm2D::<F>::new::<SmallRng>(
530// config.dims[0],
531// 1e-6,
532// Some("stem_norm"),
533// )?);
534//
535// // Create the stages
536// let mut stages = Vec::with_capacity(config.depths.len());
537// let mut current_channels = config.dims[0];
538//
539// for (i, &depth) in config.depths.iter().enumerate() {
540// let output_channels = config.dims[i];
541// let stride = if i == 0 { 1 } else { 2 };
542//
543// let stage_config = ConvNeXtStageConfig {
544// input_channels: current_channels,
545// output_channels,
546// num_blocks: depth,
547// stride,
548// layer_scale_init_value: config.layer_scale_init_value,
549// drop_path_prob: 0.0,
550// };
551//
552// stages.push(ConvNeXtStage::<F>::new(&stage_config)?);
553// current_channels = output_channels;
554// }
555//
556// // Create the head if needed
557// let head = if config.include_top {
558// let mut head_seq = Sequential::new();
559//
560// head_seq.add(LayerNorm2D::<F>::new::<SmallRng>(
561// *config.dims.last().expect("Operation failed"),
562// 1e-6,
563// Some("head_norm"),
564// )?);
565//
566// head_seq.add(GlobalAvgPool2D::<F>::new(Some("head_pool"))?);
567//
568// if let Some(dropout_rate) = config.dropout_rate {
569// if dropout_rate > 0.0 {
570// head_seq.add(Dropout::<F>::new(dropout_rate, &mut rng)?);
571// }
572// }
573//
574// head_seq.add(Dense::<F>::new(
575// *config.dims.last().expect("Operation failed"),
576// config.num_classes,
577// Some("classifier"),
578// &mut rng,
579// )?);
580//
581// Some(head_seq)
582// } else {
583// None
584// };
585//
586// Ok(Self {
587// stem,
588// stages,
589// head,
590// config,
591// })
592// }
593//
594// /// Create a ConvNeXt-Tiny model
595// pub fn convnext_tiny(num_classes: usize, include_top: bool) -> Result<Self> {
596// let config = ConvNeXtConfig {
597// variant: ConvNeXtVariant::Tiny,
598// input_channels: 3,
599// depths: vec![3, 3, 9, 3],
600// dims: vec![96, 192, 384, 768],
601// num_classes,
602// dropout_rate: Some(0.1),
603// layer_scale_init_value: 1e-6,
604// include_top,
605// };
606// Self::new(config)
607// }
608//
609// /// Create a ConvNeXt-Small model
610// pub fn convnext_small(num_classes: usize, include_top: bool) -> Result<Self> {
611// let config = ConvNeXtConfig {
612// variant: ConvNeXtVariant::Small,
613// input_channels: 3,
614// depths: vec![3, 3, 27, 3],
615// dims: vec![96, 192, 384, 768],
616// num_classes,
617// dropout_rate: Some(0.1),
618// layer_scale_init_value: 1e-6,
619// include_top,
620// };
621// Self::new(config)
622// }
623//
624// /// Create a ConvNeXt-Base model
625// pub fn convnext_base(num_classes: usize, include_top: bool) -> Result<Self> {
626// let config = ConvNeXtConfig {
627// variant: ConvNeXtVariant::Base,
628// input_channels: 3,
629// depths: vec![3, 3, 27, 3],
630// dims: vec![128, 256, 512, 1024],
631// num_classes,
632// dropout_rate: Some(0.1),
633// layer_scale_init_value: 1e-6,
634// include_top,
635// };
636// Self::new(config)
637// }
638//
639// /// Create a ConvNeXt-Large model
640// pub fn convnext_large(num_classes: usize, include_top: bool) -> Result<Self> {
641// let config = ConvNeXtConfig {
642// variant: ConvNeXtVariant::Large,
643// input_channels: 3,
644// depths: vec![3, 3, 27, 3],
645// dims: vec![192, 384, 768, 1536],
646// num_classes,
647// dropout_rate: Some(0.1),
648// layer_scale_init_value: 1e-6,
649// include_top,
650// };
651// Self::new(config)
652// }
653//
654// /// Create a ConvNeXt-XLarge model
655// pub fn convnext_xlarge(num_classes: usize, include_top: bool) -> Result<Self> {
656// let config = ConvNeXtConfig {
657// variant: ConvNeXtVariant::XLarge,
658// input_channels: 3,
659// depths: vec![3, 3, 27, 3],
660// dims: vec![256, 512, 1024, 2048],
661// num_classes,
662// dropout_rate: Some(0.1),
663// layer_scale_init_value: 1e-6,
664// include_top,
665// };
666// Self::new(config)
667// }
668// }
669//
670// impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> Layer<F> for ConvNeXt<F> {
671// fn as_any(&self) -> &dyn std::any::Any {
672// self
673// }
674//
675// fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
676// self
677// }
678//
679// fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
680// // Apply stem
681// let mut x = self.stem.forward(input)?;
682//
683// // Apply stages
684// for stage in &self.stages {
685// x = stage.forward(&x)?;
686// }
687//
688// // Apply head if available
689// if let Some(ref head) = self.head {
690// x = head.forward(&x)?;
691// }
692//
693// Ok(x)
694// }
695//
696// fn backward(
697// &self,
698// input: &Array<F, IxDyn>,
699// grad_output: &Array<F, IxDyn>,
700// ) -> Result<Array<F, IxDyn>> {
701// let mut grad = grad_output.clone();
702//
703// // Backward through head if it exists
704// if let Some(ref head) = self.head {
705// grad = head.backward(&grad, &grad)?;
706// }
707//
708// // Backward through stages in reverse order
709// for stage in self.stages.iter().rev() {
710// grad = stage.backward(&grad, &grad)?;
711// }
712//
713// // Backward through stem
714// self.stem.backward(input, &grad)
715// }
716//
717// fn update(&mut self, learning_rate: F) -> Result<()> {
718// self.stem.update(learning_rate)?;
719//
720// for stage in &mut self.stages {
721// stage.update(learning_rate)?;
722// }
723//
724// if let Some(ref mut head) = self.head {
725// head.update(learning_rate)?;
726// }
727//
728// Ok(())
729// }
730//
731// fn params(&self) -> Vec<Array<F, IxDyn>> {
732// let mut params = Vec::new();
733// params.extend(self.stem.params());
734// for stage in &self.stages {
735// params.extend(stage.params());
736// }
737// if let Some(ref head) = self.head {
738// params.extend(head.params());
739// }
740// params
741// }
742//
743// fn set_training(&mut self, training: bool) {
744// self.stem.set_training(training);
745// for stage in &mut self.stages {
746// stage.set_training(training);
747// }
748// if let Some(ref mut head) = self.head {
749// head.set_training(training);
750// }
751// }
752//
753// fn is_training(&self) -> bool {
754// self.stem.is_training()
755// }
756// }
757
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration must describe ConvNeXt-Tiny with a 1000-class head.
    ///
    /// Every field is pinned to its exact value so an accidental change to the
    /// defaults is caught, not just a change in the number of stages.
    #[test]
    fn test_convnext_config() {
        let config = ConvNeXtConfig::default();
        assert_eq!(config.variant, ConvNeXtVariant::Tiny);
        assert_eq!(config.input_channels, 3);
        assert_eq!(config.depths, vec![3usize, 3, 9, 3]);
        assert_eq!(config.dims, vec![96usize, 192, 384, 768]);
        // Each stage needs both a depth and a channel dimension.
        assert_eq!(config.depths.len(), config.dims.len());
        assert_eq!(config.num_classes, 1000);
        assert_eq!(config.dropout_rate, Some(0.0));
        assert!((config.layer_scale_init_value - 1e-6).abs() < f64::EPSILON);
        assert!(config.include_top);
    }

    // TODO: Re-enable once LayerNorm2D is implemented
    // #[test]
    // fn test_convnext_block_creation() {
    //     let block = ConvNeXtBlock::<f64>::new(64, 1e-6, 0.0);
    //     assert!(block.is_ok());
    // }

    // TODO: Re-enable once LayerNorm2D is implemented
    // #[test]
    // fn test_convnext_stage_config() {
    //     let config = ConvNeXtStageConfig {
    //         input_channels: 64,
    //         output_channels: 128,
    //         num_blocks: 3,
    //         stride: 2,
    //         layer_scale_init_value: 1e-6,
    //         drop_path_prob: 0.0,
    //     };
    //
    //     let stage = ConvNeXtStage::<f64>::new(&config);
    //     assert!(stage.is_ok());
    // }

    // TODO: Re-enable once LayerNorm2D is implemented
    // #[test]
    // fn test_convnext_downsample() {
    //     let downsample = ConvNeXtDownsample::<f64>::new(64, 128, 2);
    //     assert!(downsample.is_ok());
    // }

    /// Every variant equals itself and differs from every other variant.
    #[test]
    fn test_convnext_variants() {
        let all = [
            ConvNeXtVariant::Tiny,
            ConvNeXtVariant::Small,
            ConvNeXtVariant::Base,
            ConvNeXtVariant::Large,
            ConvNeXtVariant::XLarge,
        ];

        for (i, lhs) in all.iter().enumerate() {
            for (j, rhs) in all.iter().enumerate() {
                assert_eq!(lhs == rhs, i == j, "{lhs:?} vs {rhs:?}");
            }
        }

        // The enum is `Copy`: the original stays usable after being copied.
        let original = ConvNeXtVariant::Base;
        let copied = original;
        assert_eq!(original, copied);
    }
}