model_architecture_visualization/
model_architecture_visualization.rs

use rand::rngs::SmallRng;
use rand::SeedableRng;
use scirs2_neural::error::Result;
use scirs2_neural::layers::{BatchNorm, Conv2D, Dense, Dropout, PaddingMode};
use scirs2_neural::models::sequential::Sequential;
use scirs2_neural::utils::colors::ColorOptions;
use scirs2_neural::utils::{sequential_model_dataflow, sequential_model_summary, ModelVizOptions};

fn main() -> Result<()> {
    println!("Model Architecture Visualization Example");
    println!("=======================================\n");

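    // Seed the RNG so layer initialization is reproducible across runs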
    let mut rng = SmallRng::seed_from_u64(42);

    println!("\n--- Example 1: MLP Architecture ---\n");
    let mlp = create_mlp_model(&mut rng)?;

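    // Print a layer-by-layer summary table for the MLP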
    let mlp_summary = sequential_model_summary(
        &mlp,
        Some(vec![32, 784]), // Input shape: batch of 32, 784 features
        Some("MLP Neural Network"),
        Some(ModelVizOptions {
            width: 80,
            show_params: true,
            show_shapes: true,
            show_properties: true,
            color_options: ColorOptions::default(),
        }),
    )?;
    println!("{}", mlp_summary);

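    // Visualize how tensor shapes change as data flows through the MLP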
    let mlp_dataflow = sequential_model_dataflow(
        &mlp,
        vec![32, 784], // Input shape: batch of 32, 784 features
        Some("MLP Data Flow"),
        None, // No custom visualization options
    )?;
    println!("\n{}", mlp_dataflow);

    println!("\n--- Example 2: CNN Architecture ---\n");
    let cnn = create_cnn_model(&mut rng)?;

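    // Enable bright terminal colors for the CNN visualizations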
    let mut color_options = ColorOptions::default();
    color_options.enabled = true;
    color_options.use_bright = true;

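    // Print a layer-by-layer summary table for the CNN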
    let cnn_summary = sequential_model_summary(
        &cnn,
        Some(vec![32, 28, 28, 1]), // Input shape: batch of 32 28x28 grayscale images
        Some("CNN Neural Network"),
        Some(ModelVizOptions {
            width: 80,
            show_params: true,
            show_shapes: true,
            show_properties: true,
            color_options,
        }),
    )?;
    println!("{}", cnn_summary);

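    // Visualize how tensor shapes change as data flows through the CNN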
    let cnn_dataflow = sequential_model_dataflow(
        &cnn,
        vec![32, 28, 28, 1], // Input shape: batch of 32 28x28 grayscale images
        Some("CNN Data Flow"),
        Some(ModelVizOptions {
            width: 80,
            show_params: true,
            show_shapes: true,
            show_properties: false,
            color_options,
        }),
    )?;
    println!("\n{}", cnn_dataflow);

    println!("\n--- Example 3: RNN (LSTM) Architecture ---\n");
    println!("Skipping RNN example due to threading constraints with the LSTM implementation.");

    println!("\nModel Architecture Visualization Complete!");
    Ok(())
}

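/// Builds a simple MLP classifier: 784 inputs -> 512 -> 256 -> 128 -> 10 outputs.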
fn create_mlp_model(rng: &mut SmallRng) -> Result<Sequential<f32>> {
    let mut model = Sequential::new();

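    // Hidden layers with ReLU activations, interleaved with dropout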
    let dense1 = Dense::new(784, 512, Some("relu"), rng)?;
    model.add_layer(dense1);

    let dropout1 = Dropout::new(0.2, rng)?;
    model.add_layer(dropout1);

    let dense2 = Dense::new(512, 256, Some("relu"), rng)?;
    model.add_layer(dense2);

    let dense3 = Dense::new(256, 128, Some("relu"), rng)?;
    model.add_layer(dense3);

    let dropout2 = Dropout::new(0.3, rng)?;
    model.add_layer(dropout2);

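    // Output layer: 10 classes with softmax activation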
    let dense4 = Dense::new(128, 10, Some("softmax"), rng)?;
    model.add_layer(dense4);

    Ok(model)
}

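/// Builds a small CNN for single-channel 28x28 inputs (e.g. MNIST-sized images).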
fn create_cnn_model(rng: &mut SmallRng) -> Result<Sequential<f32>> {
    let mut model = Sequential::new();

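    // Conv block 1: 1 -> 32 channels, 3x3 kernel, stride 1, padding 1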
    let conv1 = Conv2D::new(
        1,                      // input channels
        32,                     // output channels (filters)
        (3, 3),                 // kernel size
        (1, 1),                 // stride
        PaddingMode::Custom(1), // padding of 1 pixel on each side
        rng,
    )?;
    model.add_layer(conv1);

    let batch_norm1 = BatchNorm::new(32, 0.99, 1e-5, rng)?;
    model.add_layer(batch_norm1);

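    // Conv block 2: 32 -> 64 channels, stride 2 downsamples 28x28 -> 14x14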
    let conv2 = Conv2D::new(
        32,                     // input channels
        64,                     // output channels (filters)
        (3, 3),                 // kernel size
        (2, 2),                 // stride
        PaddingMode::Custom(1), // padding of 1 pixel on each side
        rng,
    )?;
    model.add_layer(conv2);

    let batch_norm2 = BatchNorm::new(64, 0.99, 1e-5, rng)?;
    model.add_layer(batch_norm2);

    let dropout1 = Dropout::new(0.25, rng)?;
    model.add_layer(dropout1);

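    // Dense head: flattened 64 x 14 x 14 feature map -> 128 hidden units -> 10 classes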
    let dense1 = Dense::new(64 * 14 * 14, 128, Some("relu"), rng)?;
    model.add_layer(dense1);

    let dropout2 = Dropout::new(0.5, rng)?;
    model.add_layer(dropout2);

    let dense2 = Dense::new(128, 10, Some("softmax"), rng)?;
    model.add_layer(dense2);

    Ok(model)
}