model_visualization_example/
model_visualization_example.rs1use ndarray::{Array, Array1, Array2};
2use rand::rngs::SmallRng;
3use rand::{Rng, SeedableRng};
4use scirs2_neural::error::Result;
5use scirs2_neural::utils::{ConfusionMatrix, FeatureImportance, LearningCurve, ROCCurve};
6
7fn main() -> Result<()> {
8 println!("Neural Network Model Evaluation Visualization Example\n");
9
10 let n_samples = 500;
12 let n_features = 10;
13 let n_classes = 4;
14
15 println!(
16 "Generating {} samples with {} features for {} classes",
17 n_samples, n_features, n_classes
18 );
19
20 println!("\n--- Confusion Matrix Visualization ---\n");
22
23 let mut rng = SmallRng::seed_from_u64(42);
25
26 let y_true = Array::from_shape_fn(n_samples, |_| rng.random_range(0..n_classes));
28
29 let y_pred = Array::from_shape_fn(n_samples, |i| {
31 if rng.random::<f32>() < 0.7 {
32 y_true[i]
34 } else {
35 rng.random_range(0..n_classes)
37 }
38 });
39
40 let class_labels = vec![
42 "Class A".to_string(),
43 "Class B".to_string(),
44 "Class C".to_string(),
45 "Class D".to_string(),
46 ];
47
48 let cm = ConfusionMatrix::<f32>::new(
49 &y_true.view(),
50 &y_pred.view(),
51 Some(n_classes),
52 Some(class_labels),
53 )?;
54
55 println!("Raw Confusion Matrix:\n");
57 println!("{}", cm.to_ascii(Some("Confusion Matrix"), false));
58
59 println!("\nNormalized Confusion Matrix:\n");
60 println!("{}", cm.to_ascii(Some("Normalized Confusion Matrix"), true));
61
62 println!("\nAccuracy: {:.3}", cm.accuracy());
64
65 let precision = cm.precision();
66 let recall = cm.recall();
67 let f1 = cm.f1_score();
68
69 println!("Per-class metrics:");
70 for i in 0..n_classes {
71 println!(
72 " Class {}: Precision={:.3}, Recall={:.3}, F1={:.3}",
73 i, precision[i], recall[i], f1[i]
74 );
75 }
76
77 println!("Macro F1 Score: {:.3}", cm.macro_f1());
78
79 println!("\n--- Feature Importance Visualization ---\n");
81
82 let feature_names = (0..n_features)
84 .map(|i| format!("Feature_{}", i))
85 .collect::<Vec<String>>();
86
87 let importance = Array1::from_shape_fn(n_features, |i| {
88 let base = (n_features - i) as f32 / n_features as f32;
90 base + 0.2 * rng.random::<f32>()
91 });
92
93 let fi = FeatureImportance::new(feature_names, importance)?;
94
95 println!("{}", fi.to_ascii(Some("Feature Importance"), 60, None));
97
98 println!("\nTop 5 Most Important Features:\n");
100 println!("{}", fi.to_ascii(Some("Top 5 Features"), 60, Some(5)));
101
102 println!("\n--- ROC Curve Visualization ---\n");
104
105 let n_binary = 200;
107 let y_true_binary = Array::from_shape_fn(n_binary, |_| rng.random_range(0..2));
108
109 let y_scores = Array1::from_shape_fn(n_binary, |i| {
111 if y_true_binary[i] == 1 {
112 0.6 + 0.4 * rng.random::<f32>()
114 } else {
115 0.4 * rng.random::<f32>()
117 }
118 });
119
120 let roc = ROCCurve::new(&y_true_binary.view(), &y_scores.view())?;
121
122 println!("ROC AUC: {:.3}", roc.auc);
123 println!("\n{}", roc.to_ascii(None, 50, 20));
124
125 println!("\n--- Learning Curve Visualization ---\n");
127
128 let n_points = 10;
130 let n_cv = 5;
131
132 let train_sizes = Array1::from_shape_fn(n_points, |i| 50 + i * 50);
133
134 let train_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
136 0.95 - 0.05 * (i as f32 / n_points as f32) + 0.03 * rng.random::<f32>()
137 });
138
139 let val_scores = Array2::from_shape_fn((n_points, n_cv), |(i, _)| {
141 0.7 + 0.2 * (i as f32 / n_points as f32) + 0.05 * rng.random::<f32>()
142 });
143
144 let lc = LearningCurve::new(train_sizes, train_scores, val_scores)?;
145
146 println!("{}", lc.to_ascii(None, 60, 20, "Accuracy"));
147
148 println!("\nModel evaluation visualizations completed successfully!");
150
151 Ok(())
152}