// efficientnet_example/efficientnet_example.rs

1use ndarray::{Array, IxDyn};
2use scirs2_neural::layers::Layer;
3use scirs2_neural::models::{EfficientNet, EfficientNetConfig};
4
5fn main() -> Result<(), Box<dyn std::error::Error>> {
6    println!("EfficientNet Example");
7
8    // Create EfficientNet-B0 model for image classification
9    let input_channels = 3; // RGB images
10    let num_classes = 1000; // ImageNet classes
11
12    println!(
13        "Creating EfficientNet-B0 model with {} input channels and {} output classes",
14        input_channels, num_classes
15    );
16
17    // Create model
18    let model = EfficientNet::<f32>::efficientnet_b0(input_channels, num_classes)?;
19
20    // Create dummy input (batch_size=1, channels=3, height=224, width=224)
21    let input = Array::from_shape_fn(IxDyn(&[1, input_channels, 224, 224]), |_| {
22        rand::random::<f32>()
23    });
24
25    println!("Input shape: {:?}", input.shape());
26
27    // Forward pass
28    let output = model.forward(&input)?;
29
30    println!("Output shape: {:?}", output.shape());
31    println!("Output contains logits for {} classes", output.shape()[1]);
32
33    // Create EfficientNet-B3 model (larger model)
34    println!("\nCreating EfficientNet-B3 model...");
35
36    let model_b3 = EfficientNet::<f32>::efficientnet_b3(input_channels, num_classes)?;
37
38    // Create dummy input with higher resolution for B3 (300x300)
39    let input_b3 = Array::from_shape_fn(IxDyn(&[1, input_channels, 300, 300]), |_| {
40        rand::random::<f32>()
41    });
42
43    println!("Input shape for B3: {:?}", input_b3.shape());
44
45    // Forward pass
46    let output_b3 = model_b3.forward(&input_b3)?;
47
48    println!("Output shape for B3: {:?}", output_b3.shape());
49
50    // Create a custom EfficientNet model for smaller images
51    println!("\nCreating a custom EfficientNet model for smaller images...");
52
53    // Create simplified config with fewer stages
54    let mut custom_config = EfficientNetConfig::efficientnet_b0(input_channels, 10); // 10 classes
55
56    // Simplify by keeping only first 4 stages
57    custom_config.stages.truncate(4);
58
59    // Scale down the model
60    custom_config.width_coefficient = 0.5;
61    custom_config.depth_coefficient = 0.5;
62    custom_config.resolution = 32; // For CIFAR-10 size images
63
64    let custom_model = EfficientNet::<f32>::new(custom_config)?;
65
66    // Create dummy input for small images (32x32)
67    let small_input = Array::from_shape_fn(IxDyn(&[1, input_channels, 32, 32]), |_| {
68        rand::random::<f32>()
69    });
70
71    println!("Custom input shape: {:?}", small_input.shape());
72
73    // Forward pass
74    let custom_output = custom_model.forward(&small_input)?;
75
76    println!("Custom output shape: {:?}", custom_output.shape());
77    println!(
78        "Custom model produces logits for {} classes",
79        custom_output.shape()[1]
80    );
81
82    println!("\nEfficientNet example completed successfully!");
83
84    Ok(())
85}