// model_visualization_example/model_visualization_example.rs

use ndarray::{Array, Array1, Array2};
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
use scirs2_neural::error::Result;
use scirs2_neural::utils::{ConfusionMatrix, FeatureImportance, LearningCurve, ROCCurve};

7fn main() -> Result<()> {
8    println!("Neural Network Model Evaluation Visualization Example\n");
9
10    // Generate some example data
11    let n_samples = 500;
12    let n_features = 10;
13    let n_classes = 4;
14
15    println!(
16        "Generating {} samples with {} features for {} classes",
17        n_samples, n_features, n_classes
18    );
19
20    // 1. Confusion Matrix Example
21    println!("\n--- Confusion Matrix Visualization ---\n");
22
23    // Create a deterministic RNG for reproducibility
24    let mut rng = SmallRng::seed_from_u64(42);
25
26    // Generate random predictions and true labels
27    let y_true = Array::from_shape_fn(n_samples, |_| rng.random_range(0..n_classes));
28
29    // Create slightly correlated predictions (not completely random)
30    let y_pred = Array::from_shape_fn(n_samples, |i| {
31        if rng.random::<f32>() < 0.7 {
32            // 70% chance of correct prediction
33            y_true[i]
34        } else {
35            // 30% chance of random class
36            rng.random_range(0..n_classes)
37        }
38    });
39
40    // Create confusion matrix
41    let class_labels = vec![
42        "Class A".to_string(),
43        "Class B".to_string(),
44        "Class C".to_string(),
45        "Class D".to_string(),
46    ];
47
48    let cm = ConfusionMatrix::<f32>::new(
49        &y_true.view(),
50        &y_pred.view(),
51        Some(n_classes),
52        Some(class_labels),
53    )?;
54
55    // Print raw and normalized confusion matrices
56    println!("Raw Confusion Matrix:\n");
57    println!("{}", cm.to_ascii(Some("Confusion Matrix"), false));
58
59    println!("\nNormalized Confusion Matrix:\n");
60    println!("{}", cm.to_ascii(Some("Normalized Confusion Matrix"), true));
61
62    // Print metrics
63    println!("\nAccuracy: {:.3}", cm.accuracy());
64
65    let precision = cm.precision();
66    let recall = cm.recall();
67    let f1 = cm.f1_score();
68
69    println!("Per-class metrics:");
70    for i in 0..n_classes {
71        println!(
72            "  Class {}: Precision={:.3}, Recall={:.3}, F1={:.3}",
73            i, precision[i], recall[i], f1[i]
74        );
75    }
76
77    println!("Macro F1 Score: {:.3}", cm.macro_f1());
78
79    // 2. Feature Importance Visualization
80    println!("\n--- Feature Importance Visualization ---\n");
81
82    // Generate random feature importance scores
83    let feature_names = (0..n_features)
84        .map(|i| format!("Feature_{}", i))
85        .collect::<Vec<String>>();
86
87    let importance = Array1::from_shape_fn(n_features, |i| {
88        // Make some features more important than others
89        let base = (n_features - i) as f32 / n_features as f32;
90        base + 0.2 * rng.random::<f32>()
91    });
92
93    let fi = FeatureImportance::new(feature_names, importance)?;
94
95    // Print full feature importance
96    println!("{}", fi.to_ascii(Some("Feature Importance"), 60, None));
97
98    // Print top-5 features
99    println!("\nTop 5 Most Important Features:\n");
100    println!("{}", fi.to_ascii(Some("Top 5 Features"), 60, Some(5)));
101
102    // 3. ROC Curve for Binary Classification
103    println!("\n--- ROC Curve Visualization ---\n");
104
105    // Generate binary classification data
106    let n_binary = 200;
107    let y_true_binary = Array::from_shape_fn(n_binary, |_| rng.random_range(0..2));
108
109    // Generate scores with some predictive power
110    let y_scores = Array1::from_shape_fn(n_binary, |i| {
111        if y_true_binary[i] == 1 {
112            // Higher scores for positive class
113            0.6 + 0.4 * rng.random::<f32>()
114        } else {
115            // Lower scores for negative class
116            0.4 * rng.random::<f32>()
117        }
118    });
119
120    let roc = ROCCurve::new(&y_true_binary.view(), &y_scores.view())?;
121
122    println!("ROC AUC: {:.3}", roc.auc);
123    println!("\n{}", roc.to_ascii(None, 50, 20));
124
125    // 4. Learning Curve Visualization
126    println!("\n--- Learning Curve Visualization ---\n");
127
128    // Generate learning curve data
129    let n_points = 10;
130    let n_cv = 5;
131
132    let train_sizes = Array1::from_shape_fn(n_points, |i| 50 + i * 50);
133
134    // Generate training scores (decreasing with size due to overfitting)
135    let train_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
136        0.95 - 0.05 * (i as f32 / n_points as f32) + 0.03 * rng.random::<f32>()
137    });
138
139    // Generate validation scores (increasing with size)
140    let val_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
141        0.7 + 0.2 * (i as f32 / n_points as f32) + 0.05 * rng.random::<f32>()
142    });
143
144    let lc = LearningCurve::new(train_sizes, train_scores, val_scores)?;
145
146    println!("{}", lc.to_ascii(None, 60, 20, "Accuracy"));
147
148    // Print final message
149    println!("\nModel evaluation visualizations completed successfully!");
150
151    Ok(())
152}