// neural_confusion_matrix/neural_confusion_matrix.rs

use ndarray::{Array1, Array2, ArrayD, Axis};
use rand::seq::SliceRandom;
use rand::{rngs::SmallRng, Rng, SeedableRng};
use scirs2_neural::callbacks::{Callback, CallbackContext, CallbackTiming, VisualizationCallback};
use scirs2_neural::error::Result;
use scirs2_neural::layers::Dense;
use scirs2_neural::losses::MeanSquaredError;
use scirs2_neural::models::{sequential::Sequential, Model};
use scirs2_neural::optimizers::Adam;
use scirs2_neural::utils::evaluation::ConfusionMatrix;
use std::collections::HashMap;
use std::f32::consts::PI;

// Generate a spiral dataset for multi-class classification
fn generate_spiral_dataset(
    n_samples: usize,
    n_classes: usize,
    noise: f32,
    rng: &mut SmallRng,
) -> (Array2<f32>, Array1<usize>) {
    let mut x = Array2::<f32>::zeros((n_samples * n_classes, 2));
    let mut y = Array1::<usize>::zeros(n_samples * n_classes);

    for j in 0..n_classes {
        // Angular offset of the j-th spiral, spacing the classes evenly
        let angle_offset = (j as f32) * 2.0 * PI / (n_classes as f32);

        for i in 0..n_samples {
            // Parameter t runs from 0 to 1 along the spiral
            let t = (i as f32) / (n_samples as f32);
            let radius = 2.0 * t;

            // Angle: 1.5 turns along the spiral, shifted by the class offset
            let theta = 1.5 * t * 2.0 * PI + angle_offset;

            // Point coordinates with uniform noise
            let x1 = radius * f32::cos(theta) + noise * rng.random_range(-1.0..1.0);
            let x2 = radius * f32::sin(theta) + noise * rng.random_range(-1.0..1.0);

            // Store the point and label
            let idx = j * n_samples + i;
            x[[idx, 0]] = x1;
            x[[idx, 1]] = x2;
            y[idx] = j;
        }
    }

    (x, y)
}

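// Usage sketch for the generator above (the shapes and layout follow directly
// from the loop): samples are stored class-major, so with n_samples = 100 and
// n_classes = 3 the result is a (300, 2) point matrix whose rows 0..100 are
// class 0, 100..200 class 1, and 200..300 class 2.
//
//     let mut rng = SmallRng::seed_from_u64(0);
//     let (x, y) = generate_spiral_dataset(100, 3, 0.15, &mut rng);
//     assert_eq!(x.shape(), &[300, 2]);
//     assert_eq!(y[0], 0);
//     assert_eq!(y[299], 2);
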
// Create a simple classification model
fn create_classification_model(
    input_dim: usize,
    hidden_dim: usize,
    output_dim: usize,
    rng: &mut SmallRng,
) -> Result<Sequential<f32>> {
    let mut model = Sequential::new();

    // First hidden layer
    let dense1 = Dense::new(input_dim, hidden_dim, Some("relu"), rng)?;
    model.add_layer(dense1);

    // Second hidden layer
    let dense2 = Dense::new(hidden_dim, hidden_dim / 2, Some("relu"), rng)?;
    model.add_layer(dense2);

    // Output layer
    let dense3 = Dense::new(hidden_dim / 2, output_dim, Some("sigmoid"), rng)?;
    model.add_layer(dense3);

    Ok(model)
}

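// Note on the architecture: the output layer uses a sigmoid activation and the
// training loop below pairs it with mean squared error against one-hot targets.
// Softmax plus cross-entropy is the more conventional choice for multi-class
// classification; sigmoid + MSE keeps this example self-contained, and the
// argmax in `predictions_to_classes` still picks a single class per sample.
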
// Convert one-hot encoded predictions to class indices
fn predictions_to_classes(predictions: &ArrayD<f32>) -> Array1<usize> {
    let shape = predictions.shape();
    let n_samples = shape[0];
    let n_classes = shape[1];
    let mut classes = Array1::zeros(n_samples);

    for i in 0..n_samples {
        // Argmax over the i-th row
        let mut max_val = predictions[[i, 0]];
        let mut max_idx = 0;

        // Find the index of the highest value
        for j in 1..n_classes {
            let val = predictions[[i, j]];
            if val > max_val {
                max_val = val;
                max_idx = j;
            }
        }

        classes[i] = max_idx;
    }

    classes
}

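// For example, a prediction row of [0.1, 0.7, 0.2] maps to class 1. Ties keep
// the lowest index, since only a strictly greater value replaces the maximum.
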
// Helper to convert class indices to one-hot encoded vectors
fn one_hot_encode(y: &Array1<usize>, n_classes: usize) -> Array2<f32> {
    let n_samples = y.len();
    let mut one_hot = Array2::zeros((n_samples, n_classes));

    for i in 0..n_samples {
        let class_idx = y[i];
        if class_idx < n_classes {
            one_hot[[i, class_idx]] = 1.0;
        }
    }

    one_hot
}

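// For example, y = [2, 0] with n_classes = 3 encodes to:
//     [[0.0, 0.0, 1.0],
//      [1.0, 0.0, 0.0]]
// An out-of-range class index is skipped, leaving an all-zero row.
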
fn main() -> Result<()> {
    println!("Neural Network Confusion Matrix Visualization");
    println!("==============================================\n");

    // Initialize RNG with a fixed seed for reproducibility
    let mut rng = SmallRng::seed_from_u64(42);

    // Generate spiral dataset for 3-class classification
    let n_classes = 3;
    let n_samples_per_class = 100;
    let noise = 0.15;

    let (x, y) = generate_spiral_dataset(n_samples_per_class, n_classes, noise, &mut rng);
    println!(
        "Generated spiral dataset with {} classes, {} samples per class",
        n_classes, n_samples_per_class
    );

    // Shuffle the samples before splitting: the dataset is generated
    // class-by-class, so an unshuffled 80/20 split would leave only the
    // last class in the test set.
    let n_samples = x.shape()[0];
    let mut indices: Vec<usize> = (0..n_samples).collect();
    indices.shuffle(&mut rng);
    let x = x.select(Axis(0), &indices);
    let y = y.select(Axis(0), &indices);

    // Split data into training and test sets (80/20 split)
    let n_train = (n_samples as f32 * 0.8) as usize;
    let n_test = n_samples - n_train;

    let x_train = x.slice(ndarray::s![0..n_train, ..]).to_owned();
    let y_train = y.slice(ndarray::s![0..n_train]).to_owned();
    let x_test = x.slice(ndarray::s![n_train.., ..]).to_owned();
    let y_test = y.slice(ndarray::s![n_train..]).to_owned();

    println!(
        "Split data into {} training and {} test samples",
        n_train, n_test
    );

    // Create a classification model
    let input_dim = 2; // 2D input (x, y coordinates)
    let hidden_dim = 32; // Hidden layer size
    let output_dim = n_classes; // One output per class

    let mut model = create_classification_model(input_dim, hidden_dim, output_dim, &mut rng)?;
    println!("Created model with {} layers", model.num_layers());

    // Set up the loss function and optimizer
    let loss_fn = MeanSquaredError::new();
    // Adam with learning rate 0.01 and the conventional 0.9/0.999/1e-8
    // moment and epsilon values
    let mut optimizer = Adam::new(0.01, 0.9, 0.999, 1e-8);

    // Prepare the training data
    let epochs = 100;
    let x_train_dyn = x_train.clone().into_dyn();
    let y_train_onehot = one_hot_encode(&y_train, n_classes);
    let y_train_onehot_dyn = y_train_onehot.into_dyn();

    // Create visualization callback for training metrics
    let mut visualization_cb = VisualizationCallback::new(10) // Show every 10 epochs
        .with_tracked_metrics(vec![
            "train_loss".to_string(),
            "val_accuracy".to_string(),
        ]);

    // Define class labels for confusion matrix
    let class_labels = vec![
        "Class A".to_string(),
        "Class B".to_string(),
        "Class C".to_string(),
    ];

    // Train the model (simple manual training loop)
    println!("\nTraining model...");

    // Initialize history for tracking metrics
    let mut epoch_history = HashMap::new();
    epoch_history.insert("train_loss".to_string(), Vec::new());
    epoch_history.insert("val_accuracy".to_string(), Vec::new());

    // Training loop
    for epoch in 0..epochs {
        // Train for one epoch
        let train_loss =
            model.train_batch(&x_train_dyn, &y_train_onehot_dyn, &loss_fn, &mut optimizer)?;

        // Evaluate on the test set
        let x_test_dyn = x_test.clone().into_dyn();
        let predictions = model.forward(&x_test_dyn)?;
        let predicted_classes = predictions_to_classes(&predictions);

        // Calculate validation accuracy
        let mut correct = 0;
        for i in 0..n_test {
            if predicted_classes[i] == y_test[i] {
                correct += 1;
            }
        }
        let val_accuracy = correct as f32 / n_test as f32;

        // Store metrics
        epoch_history
            .get_mut("train_loss")
            .unwrap()
            .push(train_loss);
        epoch_history
            .get_mut("val_accuracy")
            .unwrap()
            .push(val_accuracy);

        // Print progress
        if (epoch + 1) % 10 == 0 || epoch == 0 {
            println!(
                "Epoch {}/{}: loss = {:.6}, val_accuracy = {:.4}",
                epoch + 1,
                epochs,
                train_loss,
                val_accuracy
            );
        }

        // Update visualization callback
        let mut context = CallbackContext {
            epoch,
            total_epochs: epochs,
            batch: 0,
            total_batches: 1,
            batch_loss: None,
            epoch_loss: Some(train_loss),
            val_loss: None,
            metrics: vec![val_accuracy],
            history: &epoch_history,
            stop_training: false,
            model: None,
        };

        // Visualize progress with metrics chart
        if epoch % 10 == 0 || epoch == epochs - 1 {
            visualization_cb.on_event(CallbackTiming::AfterEpoch, &mut context)?;
        }

        // Calculate and show confusion matrix during training
        if epoch % 20 == 0 || epoch == epochs - 1 {
            // Create confusion matrix
            let cm = ConfusionMatrix::<f32>::new(
                &y_test.view(),
                &predicted_classes.view(),
                Some(n_classes),
                Some(class_labels.clone()),
            )?;

            // Show heatmap visualization
            println!("\nConfusion Matrix at Epoch {}:", epoch + 1);
            println!(
                "{}",
                cm.to_heatmap(
                    Some(&format!("Confusion Matrix - Epoch {}", epoch + 1)),
                    true
                )
            );
        }
    }

    // Final evaluation
    println!("\nFinal model evaluation:");

    // Make predictions on test set
    let x_test_dyn = x_test.clone().into_dyn();
    let predictions = model.forward(&x_test_dyn)?;
    let predicted_classes = predictions_to_classes(&predictions);

    // Create confusion matrix
    let cm = ConfusionMatrix::<f32>::new(
        &y_test.view(),
        &predicted_classes.view(),
        Some(n_classes),
        Some(class_labels.clone()),
    )?;

    // Calculate and show metrics
    let accuracy = cm.accuracy();
    let precision = cm.precision();
    let recall = cm.recall();
    let f1 = cm.f1_score();

    println!("\nFinal Classification Metrics:");
    println!("Overall Accuracy: {:.4}", accuracy);

    println!("\nPer-Class Metrics:");
    println!("Class   | Precision | Recall | F1-Score");
    println!("---------------------------------------");

    for i in 0..n_classes {
        println!(
            "{:<7} | {:>9.4} | {:>6.4} | {:>8.4}",
            class_labels[i], precision[i], recall[i], f1[i]
        );
    }

    println!("\nMacro F1-Score: {:.4}", cm.macro_f1());

    // Show different confusion matrix visualizations
    println!("\nFinal Confusion Matrix Visualizations:");

    // 1. Standard confusion matrix
    println!("\n1. Standard Confusion Matrix:");
    println!("{}", cm.to_ascii(Some("Final Confusion Matrix"), false));

    // 2. Normalized confusion matrix
    println!("\n2. Normalized Confusion Matrix:");
    println!(
        "{}",
        cm.to_ascii(Some("Final Normalized Confusion Matrix"), true)
    );

    // 3. Confusion matrix heatmap
    println!("\n3. Confusion Matrix Heatmap:");
    println!(
        "{}",
        cm.to_heatmap(Some("Final Confusion Matrix Heatmap"), true)
    );

    // 4. Error pattern analysis
    println!("\n4. Error Pattern Analysis:");
    println!("{}", cm.error_heatmap(Some("Final Error Pattern Analysis")));

    println!("\nNeural Network Confusion Matrix Visualization Complete!");
    Ok(())
}
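
// Minimal sanity tests for the helpers above, sketched under the assumption
// that this file is compiled as a test-enabled target (for a Cargo example,
// that means `test = true` on its [[example]] entry). They check only
// properties defined in this file.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn one_hot_then_argmax_round_trips() {
        let y = Array1::from(vec![0usize, 2, 1, 1, 0]);
        let encoded = one_hot_encode(&y, 3);
        assert_eq!(encoded.shape(), &[5, 3]);
        // Argmax over each one-hot row recovers the original labels.
        let decoded = predictions_to_classes(&encoded.into_dyn());
        assert_eq!(decoded, y);
    }

    #[test]
    fn spiral_dataset_has_class_major_layout() {
        let mut rng = SmallRng::seed_from_u64(0);
        let (x, y) = generate_spiral_dataset(10, 3, 0.1, &mut rng);
        assert_eq!(x.shape(), &[30, 2]);
        // Samples are stored class-by-class: rows 0..10 are class 0,
        // rows 20..30 are class 2.
        assert_eq!(y[0], 0);
        assert_eq!(y[29], 2);
    }
}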