model_visualization_cnn/
model_visualization_cnn.rs

1use rand::rngs::SmallRng;
2use rand::SeedableRng;
3use scirs2_neural::error::Result;
4use scirs2_neural::layers::{Conv2D, Dense, Dropout, MaxPool2D, PaddingMode};
5use scirs2_neural::models::sequential::Sequential;
6use scirs2_neural::utils::colors::ColorOptions;
7use scirs2_neural::utils::{sequential_model_dataflow, sequential_model_summary, ModelVizOptions};
8
9fn main() -> Result<()> {
10    // Initialize random number generator
11    let mut rng = SmallRng::seed_from_u64(42);
12
13    // Create a CNN model
14    let model = create_cnn_model(&mut rng)?;
15
16    // Display model summary
17    let summary = sequential_model_summary(
18        &model,
19        Some(vec![32, 3, 224, 224]), // Input shape (batch_size, channels, height, width)
20        Some("CNN Architecture"),
21        Some(ModelVizOptions {
22            width: 100,
23            show_params: true,
24            show_shapes: true,
25            show_properties: true,
26            color_options: ColorOptions {
27                enabled: true,
28                ..Default::default()
29            },
30        }),
31    )?;
32    println!("{}", summary);
33
34    // Display model dataflow
35    let dataflow = sequential_model_dataflow(
36        &model,
37        vec![32, 3, 224, 224], // Input shape
38        Some("CNN Data Flow Diagram"),
39        Some(ModelVizOptions {
40            width: 80,
41            show_params: true,
42            show_shapes: true,
43            show_properties: false,
44            color_options: ColorOptions {
45                enabled: true,
46                ..Default::default()
47            },
48        }),
49    )?;
50    println!("\n{}", dataflow);
51
52    Ok(())
53}
54
55// Create a simple CNN model (VGG-like)
56fn create_cnn_model<R: rand::Rng + Clone + Send + Sync + 'static>(
57    rng: &mut R,
58) -> Result<Sequential<f64>> {
59    let mut model = Sequential::new();
60
61    // Block 1
62    model.add_layer(Conv2D::new(3, 64, (3, 3), (1, 1), PaddingMode::Same, rng)?);
63    model.add_layer(Conv2D::new(64, 64, (3, 3), (1, 1), PaddingMode::Same, rng)?);
64    model.add_layer(MaxPool2D::new((2, 2), (2, 2), None)?);
65
66    // Block 2
67    model.add_layer(Conv2D::new(
68        64,
69        128,
70        (3, 3),
71        (1, 1),
72        PaddingMode::Same,
73        rng,
74    )?);
75    model.add_layer(Conv2D::new(
76        128,
77        128,
78        (3, 3),
79        (1, 1),
80        PaddingMode::Same,
81        rng,
82    )?);
83    model.add_layer(MaxPool2D::new((2, 2), (2, 2), None)?);
84
85    // Block 3
86    model.add_layer(Conv2D::new(
87        128,
88        256,
89        (3, 3),
90        (1, 1),
91        PaddingMode::Same,
92        rng,
93    )?);
94    model.add_layer(Conv2D::new(
95        256,
96        256,
97        (3, 3),
98        (1, 1),
99        PaddingMode::Same,
100        rng,
101    )?);
102    model.add_layer(MaxPool2D::new((2, 2), (2, 2), None)?);
103
104    // Fully connected layers
105    // In a real implementation we would need to add a Flatten layer here
106    model.add_layer(Dense::new(256 * 28 * 28, 512, Some("relu"), rng)?);
107    model.add_layer(Dropout::new(0.5, rng)?);
108    model.add_layer(Dense::new(512, 10, Some("softmax"), rng)?);
109
110    Ok(model)
111}