// efficientnet_example/
// efficientnet_example.rs
use ndarray::{Array, IxDyn};
use scirs2_neural::layers::Layer;
use scirs2_neural::models::{EfficientNet, EfficientNetConfig};
4
5fn main() -> Result<(), Box<dyn std::error::Error>> {
6 println!("EfficientNet Example");
7
8 let input_channels = 3; let num_classes = 1000; println!(
13 "Creating EfficientNet-B0 model with {} input channels and {} output classes",
14 input_channels, num_classes
15 );
16
17 let model = EfficientNet::<f32>::efficientnet_b0(input_channels, num_classes)?;
19
20 let input = Array::from_shape_fn(IxDyn(&[1, input_channels, 224, 224]), |_| {
22 rand::random::<f32>()
23 });
24
25 println!("Input shape: {:?}", input.shape());
26
27 let output = model.forward(&input)?;
29
30 println!("Output shape: {:?}", output.shape());
31 println!("Output contains logits for {} classes", output.shape()[1]);
32
33 println!("\nCreating EfficientNet-B3 model...");
35
36 let model_b3 = EfficientNet::<f32>::efficientnet_b3(input_channels, num_classes)?;
37
38 let input_b3 = Array::from_shape_fn(IxDyn(&[1, input_channels, 300, 300]), |_| {
40 rand::random::<f32>()
41 });
42
43 println!("Input shape for B3: {:?}", input_b3.shape());
44
45 let output_b3 = model_b3.forward(&input_b3)?;
47
48 println!("Output shape for B3: {:?}", output_b3.shape());
49
50 println!("\nCreating a custom EfficientNet model for smaller images...");
52
53 let mut custom_config = EfficientNetConfig::efficientnet_b0(input_channels, 10); custom_config.stages.truncate(4);
58
59 custom_config.width_coefficient = 0.5;
61 custom_config.depth_coefficient = 0.5;
62 custom_config.resolution = 32; let custom_model = EfficientNet::<f32>::new(custom_config)?;
65
66 let small_input = Array::from_shape_fn(IxDyn(&[1, input_channels, 32, 32]), |_| {
68 rand::random::<f32>()
69 });
70
71 println!("Custom input shape: {:?}", small_input.shape());
72
73 let custom_output = custom_model.forward(&small_input)?;
75
76 println!("Custom output shape: {:?}", custom_output.shape());
77 println!(
78 "Custom model produces logits for {} classes",
79 custom_output.shape()[1]
80 );
81
82 println!("\nEfficientNet example completed successfully!");
83
84 Ok(())
85}