1use rand::rngs::SmallRng;
2use rand::SeedableRng;
3use scirs2_neural::error::Result;
4use scirs2_neural::layers::{Conv2D, Dense, Dropout, MaxPool2D, PaddingMode};
5use scirs2_neural::models::sequential::Sequential;
6use scirs2_neural::utils::colors::ColorOptions;
7use scirs2_neural::utils::{sequential_model_dataflow, sequential_model_summary, ModelVizOptions};
8
9fn main() -> Result<()> {
10 let mut rng = SmallRng::seed_from_u64(42);
12
13 let model = create_cnn_model(&mut rng)?;
15
16 let summary = sequential_model_summary(
18 &model,
19 Some(vec![32, 3, 224, 224]), Some("CNN Architecture"),
21 Some(ModelVizOptions {
22 width: 100,
23 show_params: true,
24 show_shapes: true,
25 show_properties: true,
26 color_options: ColorOptions {
27 enabled: true,
28 ..Default::default()
29 },
30 }),
31 )?;
32 println!("{}", summary);
33
34 let dataflow = sequential_model_dataflow(
36 &model,
37 vec![32, 3, 224, 224], Some("CNN Data Flow Diagram"),
39 Some(ModelVizOptions {
40 width: 80,
41 show_params: true,
42 show_shapes: true,
43 show_properties: false,
44 color_options: ColorOptions {
45 enabled: true,
46 ..Default::default()
47 },
48 }),
49 )?;
50 println!("\n{}", dataflow);
51
52 Ok(())
53}
54
55fn create_cnn_model<R: rand::Rng + Clone + Send + Sync + 'static>(
57 rng: &mut R,
58) -> Result<Sequential<f64>> {
59 let mut model = Sequential::new();
60
61 model.add_layer(Conv2D::new(3, 64, (3, 3), (1, 1), PaddingMode::Same, rng)?);
63 model.add_layer(Conv2D::new(64, 64, (3, 3), (1, 1), PaddingMode::Same, rng)?);
64 model.add_layer(MaxPool2D::new((2, 2), (2, 2), None)?);
65
66 model.add_layer(Conv2D::new(
68 64,
69 128,
70 (3, 3),
71 (1, 1),
72 PaddingMode::Same,
73 rng,
74 )?);
75 model.add_layer(Conv2D::new(
76 128,
77 128,
78 (3, 3),
79 (1, 1),
80 PaddingMode::Same,
81 rng,
82 )?);
83 model.add_layer(MaxPool2D::new((2, 2), (2, 2), None)?);
84
85 model.add_layer(Conv2D::new(
87 128,
88 256,
89 (3, 3),
90 (1, 1),
91 PaddingMode::Same,
92 rng,
93 )?);
94 model.add_layer(Conv2D::new(
95 256,
96 256,
97 (3, 3),
98 (1, 1),
99 PaddingMode::Same,
100 rng,
101 )?);
102 model.add_layer(MaxPool2D::new((2, 2), (2, 2), None)?);
103
104 model.add_layer(Dense::new(256 * 28 * 28, 512, Some("relu"), rng)?);
107 model.add_layer(Dropout::new(0.5, rng)?);
108 model.add_layer(Dense::new(512, 10, Some("softmax"), rng)?);
109
110 Ok(model)
111}