// convnext_example/convnext_example.rs
use ndarray::Array;
use scirs2_neural::{
    error::Result,
    models::architectures::{ConvNeXt, ConvNeXtConfig, ConvNeXtVariant},
    prelude::*,
};
7
8fn main() -> Result<()> {
9 println!("ConvNeXt Example");
10 println!("----------------");
11
12 let input_shape = [1, 3, 224, 224];
14 let mut input = Array::<f32, _>::zeros(input_shape).into_dyn();
15
16 for elem in input.iter_mut() {
18 *elem = rand::random::<f32>();
19 }
20
21 println!("\nConvNeXt-Tiny:");
23 let convnext_tiny = ConvNeXt::convnext_tiny(1000, true)?;
24 let output_tiny = convnext_tiny.forward(&input)?;
25 println!("Output shape: {:?}", output_tiny.shape());
26
27 println!("\nConvNeXt-Small:");
29 let convnext_small = ConvNeXt::convnext_small(1000, true)?;
30 let output_small = convnext_small.forward(&input)?;
31 println!("Output shape: {:?}", output_small.shape());
32
33 println!("\nConvNeXt-Base:");
35 let convnext_base = ConvNeXt::convnext_base(1000, true)?;
36 let output_base = convnext_base.forward(&input)?;
37 println!("Output shape: {:?}", output_base.shape());
38
39 println!("\nConvNeXt-Large:");
41 let convnext_large = ConvNeXt::convnext_large(1000, true)?;
42 let output_large = convnext_large.forward(&input)?;
43 println!("Output shape: {:?}", output_large.shape());
44
45 println!("\nCustom ConvNeXt:");
47 let custom_config = ConvNeXtConfig {
48 variant: ConvNeXtVariant::Tiny,
49 input_channels: 3,
50 depths: vec![3, 3, 9, 3],
51 dims: vec![96, 192, 384, 768],
52 num_classes: 10,
53 dropout_rate: Some(0.2),
54 layer_scale_init_value: 1e-6,
55 include_top: true,
56 };
57
58 let custom_convnext = ConvNeXt::new(custom_config)?;
59 let output_custom = custom_convnext.forward(&input)?;
60 println!("Output shape: {:?}", output_custom.shape());
61
62 println!("\nInference example with ConvNeXt-Tiny:");
64 let inference_input = Array::<f32, _>::zeros(input_shape).into_dyn();
65 let inference_output = convnext_tiny.forward(&inference_input)?;
66
67 let mut max_val = f32::MIN;
69 let mut max_idx = 0;
70
71 for (i, &val) in inference_output.iter().enumerate() {
72 if val > max_val {
73 max_val = val;
74 max_idx = i;
75 }
76 }
77
78 println!(
79 "Predicted class: {} with confidence: {:.4}",
80 max_idx, max_val
81 );
82
83 Ok(())
84}