use ndarray::{Array1, Array2};
use rand::seq::SliceRandom;
use rand::{rngs::SmallRng, Rng, SeedableRng};
use scirs2_neural::callbacks::{Callback, CallbackContext, CallbackTiming, VisualizationCallback};
use scirs2_neural::error::Result;
use scirs2_neural::layers::Dense;
use scirs2_neural::losses::MeanSquaredError;
use scirs2_neural::models::{sequential::Sequential, Model};
use scirs2_neural::optimizers::Adam;
use scirs2_neural::utils::evaluation::ConfusionMatrix;
use std::collections::HashMap;
use std::f32::consts::PI;

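/// Generates a noisy multi-arm spiral dataset: each class traces its own
/// arm, with the radius growing linearly and the angle sweeping 1.5 turns.
/// Returns the (n_samples * n_classes, 2) features and the class labels.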
fn generate_spiral_dataset(
    n_samples: usize,
    n_classes: usize,
    noise: f32,
    rng: &mut SmallRng,
) -> (Array2<f32>, Array1<usize>) {
    let mut x = Array2::<f32>::zeros((n_samples * n_classes, 2));
    let mut y = Array1::<usize>::zeros(n_samples * n_classes);

    for j in 0..n_classes {
        // Angular offset that separates this class's arm from the others.
        let class_offset = (j as f32) * 2.0 * PI / (n_classes as f32);

        for i in 0..n_samples {
            // t runs from 0 to just under 1 along the arm.
            let t = (i as f32) / (n_samples as f32);
            // The radius grows linearly with t, so points spiral outward.
            let radius = 2.0 * t;
            // Each arm sweeps 1.5 full turns, shifted by the class offset.
            let theta = 1.5 * t * 2.0 * PI + class_offset;

            let x1 = radius * f32::cos(theta) + noise * rng.random_range(-1.0..1.0);
            let x2 = radius * f32::sin(theta) + noise * rng.random_range(-1.0..1.0);

            let idx = j * n_samples + i;
            x[[idx, 0]] = x1;
            x[[idx, 1]] = x2;
            y[idx] = j;
        }
    }

    (x, y)
}

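/// Builds a small MLP classifier: two ReLU hidden layers (the second half
/// the width of the first) and a sigmoid output layer with one unit per class.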
fn create_classification_model(
    input_dim: usize,
    hidden_dim: usize,
    output_dim: usize,
    rng: &mut SmallRng,
) -> Result<Sequential<f32>> {
    let mut model = Sequential::new();

    // Input -> hidden.
    let dense1 = Dense::new(input_dim, hidden_dim, Some("relu"), rng)?;
    model.add_layer(dense1);

    // Hidden -> hidden, halving the width.
    let dense2 = Dense::new(hidden_dim, hidden_dim / 2, Some("relu"), rng)?;
    model.add_layer(dense2);

    // Hidden -> output: sigmoid keeps each class score in (0, 1), which
    // pairs with the mean-squared-error loss used below.
    let dense3 = Dense::new(hidden_dim / 2, output_dim, Some("sigmoid"), rng)?;
    model.add_layer(dense3);

    Ok(model)
}

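/// Converts raw model outputs of shape (n_samples, n_classes) into class
/// indices by taking the argmax of each row.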
fn predictions_to_classes(predictions: &ndarray::ArrayD<f32>) -> Array1<usize> {
    let shape = predictions.shape();
    let n_samples = shape[0];
    let n_classes = shape[1];
    let mut classes = Array1::zeros(n_samples);

    for i in 0..n_samples {
        let mut max_val = predictions[[i, 0]];
        let mut max_idx = 0;

        for j in 1..n_classes {
            let val = predictions[[i, j]];
            if val > max_val {
                max_val = val;
                max_idx = j;
            }
        }

        classes[i] = max_idx;
    }

    classes
}

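/// One-hot encodes integer labels into an (n_samples, n_classes) matrix of
/// 0.0/1.0 values; out-of-range labels are left as all zeros.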
fn one_hot_encode(y: &Array1<usize>, n_classes: usize) -> Array2<f32> {
    let n_samples = y.len();
    let mut one_hot = Array2::zeros((n_samples, n_classes));

    for i in 0..n_samples {
        let class_idx = y[i];
        if class_idx < n_classes {
            one_hot[[i, class_idx]] = 1.0;
        }
    }

    one_hot
}

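/// Trains the classifier on the spiral data and renders confusion-matrix
/// visualizations both periodically during training and at the end.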
fn main() -> Result<()> {
    println!("Neural Network Confusion Matrix Visualization");
    println!("==============================================\n");

    // Seed the RNG so dataset generation and weight init are reproducible.
    let mut rng = SmallRng::seed_from_u64(42);

    let n_classes = 3;
    let n_samples_per_class = 100;
    let noise = 0.15;

    let (x, y) = generate_spiral_dataset(n_samples_per_class, n_classes, noise, &mut rng);
    println!(
        "Generated spiral dataset with {} classes, {} samples per class",
        n_classes, n_samples_per_class
    );

    // The generator emits samples grouped by class, so shuffle before the
    // 80/20 split; otherwise the test set would contain only the last class.
    let n_samples = x.shape()[0];
    let mut indices: Vec<usize> = (0..n_samples).collect();
    indices.shuffle(&mut rng);
    let x = x.select(ndarray::Axis(0), &indices);
    let y = y.select(ndarray::Axis(0), &indices);

    let n_train = (n_samples as f32 * 0.8) as usize;
    let n_test = n_samples - n_train;

    let x_train = x.slice(ndarray::s![0..n_train, ..]).to_owned();
    let y_train = y.slice(ndarray::s![0..n_train]).to_owned();
    let x_test = x.slice(ndarray::s![n_train.., ..]).to_owned();
    let y_test = y.slice(ndarray::s![n_train..]).to_owned();

    println!(
        "Split data into {} training and {} test samples",
        n_train, n_test
    );

    // Model hyperparameters.
    let input_dim = 2;
    let hidden_dim = 32;
    let output_dim = n_classes;
    let mut model = create_classification_model(input_dim, hidden_dim, output_dim, &mut rng)?;
    println!("Created model with {} layers", model.num_layers());

    // MSE against one-hot targets is a simple choice for this demo;
    // softmax plus cross-entropy would be the more conventional pairing.
    let loss_fn = MeanSquaredError::new();
    let mut optimizer = Adam::new(0.01, 0.9, 0.999, 1e-8);

    let epochs = 100;
    // Convert to dynamic-dimension arrays once, outside the training loop.
    let x_train_dyn = x_train.clone().into_dyn();
    let y_train_onehot = one_hot_encode(&y_train, n_classes);
    let y_train_onehot_dyn = y_train_onehot.into_dyn();
    let x_test_dyn = x_test.clone().into_dyn();

    // Track train loss and validation accuracy for plotting.
    let mut visualization_cb = VisualizationCallback::new(10)
        .with_tracked_metrics(vec!["train_loss".to_string(), "val_accuracy".to_string()]);

    let class_labels = vec![
        "Class A".to_string(),
        "Class B".to_string(),
        "Class C".to_string(),
    ];

    println!("\nTraining model...");

    // Metric history handed to the visualization callback each epoch.
    let mut epoch_history = HashMap::new();
    epoch_history.insert("train_loss".to_string(), Vec::new());
    epoch_history.insert("val_accuracy".to_string(), Vec::new());

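    // Each epoch runs a single full-batch update: train_batch is called once
    // on the entire training set rather than on mini-batches.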
    for epoch in 0..epochs {
        let train_loss =
            model.train_batch(&x_train_dyn, &y_train_onehot_dyn, &loss_fn, &mut optimizer)?;

        // Evaluate accuracy on the held-out test set.
        let predictions = model.forward(&x_test_dyn)?;
        let predicted_classes = predictions_to_classes(&predictions);

        let mut correct = 0;
        for i in 0..n_test {
            if predicted_classes[i] == y_test[i] {
                correct += 1;
            }
        }
        let val_accuracy = correct as f32 / n_test as f32;

        epoch_history
            .get_mut("train_loss")
            .unwrap()
            .push(train_loss);
        epoch_history
            .get_mut("val_accuracy")
            .unwrap()
            .push(val_accuracy);

        if (epoch + 1) % 10 == 0 || epoch == 0 {
            println!(
                "Epoch {}/{}: loss = {:.6}, val_accuracy = {:.4}",
                epoch + 1,
                epochs,
                train_loss,
                val_accuracy
            );
        }

        // Hand the epoch's results to the visualization callback.
        let mut context = CallbackContext {
            epoch,
            total_epochs: epochs,
            batch: 0,
            total_batches: 1,
            batch_loss: None,
            epoch_loss: Some(train_loss),
            val_loss: None,
            metrics: vec![val_accuracy],
            history: &epoch_history,
            stop_training: false,
            model: None,
        };

        if epoch % 10 == 0 || epoch == epochs - 1 {
            visualization_cb.on_event(CallbackTiming::AfterEpoch, &mut context)?;
        }

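        // Every 20 epochs, print a confusion-matrix heatmap so class-level
        // progress is visible during training.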
        if epoch % 20 == 0 || epoch == epochs - 1 {
            let cm = ConfusionMatrix::<f32>::new(
                &y_test.view(),
                &predicted_classes.view(),
                Some(n_classes),
                Some(class_labels.clone()),
            )?;

            println!("\nConfusion Matrix at Epoch {}:", epoch + 1);
            println!(
                "{}",
                cm.to_heatmap(
                    Some(&format!("Confusion Matrix - Epoch {}", epoch + 1)),
                    true
                )
            );
        }
    }

    println!("\nFinal model evaluation:");

    // Reuse the dynamic-dimension test inputs prepared before training.
    let predictions = model.forward(&x_test_dyn)?;
    let predicted_classes = predictions_to_classes(&predictions);

    let cm = ConfusionMatrix::<f32>::new(
        &y_test.view(),
        &predicted_classes.view(),
        Some(n_classes),
        Some(class_labels.clone()),
    )?;

    // Aggregate and per-class metrics derived from the confusion matrix.
    let accuracy = cm.accuracy();
    let precision = cm.precision();
    let recall = cm.recall();
    let f1 = cm.f1_score();

    println!("\nFinal Classification Metrics:");
    println!("Overall Accuracy: {:.4}", accuracy);

    println!("\nPer-Class Metrics:");
    println!(
        "{:<7} | {:<9} | {:<6} | {}",
        "Class", "Precision", "Recall", "F1-Score"
    );
    println!("---------------------------------------");

    for i in 0..n_classes {
        println!(
            "{:<7} | {:<9.4} | {:<6.4} | {:.4}",
            class_labels[i], precision[i], recall[i], f1[i]
        );
    }

    println!("\nMacro F1-Score: {:.4}", cm.macro_f1());

    println!("\nFinal Confusion Matrix Visualizations:");

    println!("\n1. Standard Confusion Matrix:");
    println!("{}", cm.to_ascii(Some("Final Confusion Matrix"), false));

    println!("\n2. Normalized Confusion Matrix:");
    println!(
        "{}",
        cm.to_ascii(Some("Final Normalized Confusion Matrix"), true)
    );

    println!("\n3. Confusion Matrix Heatmap:");
    println!(
        "{}",
        cm.to_heatmap(Some("Final Confusion Matrix Heatmap"), true)
    );

    println!("\n4. Error Pattern Analysis:");
    println!("{}", cm.error_heatmap(Some("Final Error Pattern Analysis")));

    println!("\nNeural Network Confusion Matrix Visualization Complete!");
    Ok(())
}