scirs2_neural/utils/evaluation/feature_importance.rs
//! Feature importance visualization for machine learning models
//!
//! This module provides tools for visualizing and analyzing feature importance
//! metrics from machine learning models, helping to understand which features
//! have the greatest impact on model predictions.
6
7use crate::error::{NeuralError, Result};
8use crate::utils::colors::{
9 colored_metric_cell, colorize, gradient_color, stylize, ColorOptions, Style,
10};
11use scirs2_core::ndarray::Array1;
12use scirs2_core::numeric::Float;
13use std::fmt::{Debug, Display};
14
15/// Feature importance visualization for machine learning models
16///
17/// This struct facilitates the visualization and analysis of feature importance
18/// scores from machine learning models, helping to identify which features
19/// contribute most to predictions.
20pub struct FeatureImportance<F: Float + Debug + Display> {
21 /// Names of the features
22 pub feature_names: Vec<String>,
23 /// Importance scores for each feature
24 pub importance: Array1<F>,
25}
26
27impl<F: Float + Debug + Display> FeatureImportance<F> {
28 /// Create a new feature importance visualization
29 ///
30 /// # Arguments
31 /// * `feature_names` - Names of features
32 /// * `importance` - Importance scores
33 ///
34 /// # Returns
35 /// * `Result<FeatureImportance<F>>` - The feature importance visualization
36 ///
37 /// # Example
38 /// ```
39 /// use scirs2_core::ndarray::Array1;
40 /// use scirs2_neural::utils::evaluation::FeatureImportance;
41 ///
42 /// // Create feature names and importance scores
43 /// let feature_names = vec![
44 /// "Age".to_string(),
45 /// "Income".to_string(),
46 /// "Education".to_string(),
47 /// "Location".to_string()
48 /// ];
49 /// let importance = Array1::from_vec(vec![0.35, 0.25, 0.20, 0.10]);
50 ///
51 /// // Create feature importance visualization
52 /// let feature_importance = FeatureImportance::<f64>::new(feature_names, importance).unwrap();
53 /// ```
54 pub fn new(_featurenames: Vec<String>, importance: Array1<F>) -> Result<Self> {
55 if _featurenames.len() != importance.len() {
56 return Err(NeuralError::ValidationError(
57 "Number of feature _names must match number of importance scores".to_string(),
58 ));
59 }
60
61 Ok(FeatureImportance {
62 feature_names: _featurenames,
63 importance,
64 })
65 }
66
67 /// Get the top-k most important features
68 ///
69 /// # Arguments
70 /// * `k` - Number of top features to return
71 ///
72 /// # Returns
73 /// * `(Vec<String>, Array1<F>)` - Feature names and importance scores
74 pub fn top_k(&self, k: usize) -> (Vec<String>, Array1<F>) {
75 let k = std::cmp::min(k, self.feature_names.len());
76 let mut indices: Vec<usize> = (0..self.feature_names.len()).collect();
77
78 // Sort indices by importance (descending)
79 indices.sort_by(|&a, &b| {
80 self.importance[b]
81 .partial_cmp(&self.importance[a])
82 .unwrap_or(std::cmp::Ordering::Equal)
83 });
84
85 // Select top-k features
86 let top_indices = indices[..k].to_vec();
87 let top_names = top_indices
88 .iter()
89 .map(|&i| self.feature_names[i].clone())
90 .collect();
91 let top_importance = Array1::from_iter(top_indices.iter().map(|&i| self.importance[i]));
92
93 (top_names, top_importance)
94 }
95
96 /// Create an ASCII bar chart of feature importance
97 ///
98 /// # Arguments
99 /// * `title` - Optional title for the chart
100 /// * `width` - Width of the bar chart
101 /// * `k` - Number of top features to include (None for all)
102 ///
103 /// # Returns
104 /// * `String` - ASCII bar chart
105 pub fn to_ascii(&self, title: Option<&str>, width: usize, k: Option<usize>) -> String {
106 self.to_ascii_with_options(title, width, k, &ColorOptions::default())
107 }
108
109 /// Create an ASCII bar chart of feature importance with color options
110 /// This method creates a bar chart visualization with customizable colors,
111 /// showing feature importance scores in descending order.
112 ///
113 /// # Arguments
114 /// * `title` - Optional title for the chart
115 /// * `width` - Width of the bar chart
116 /// * `k` - Number of top features to include (None for all)
117 /// * `color_options` - Color options for visualization
118 ///
119 /// # Returns
120 /// * `String` - ASCII bar chart with colors
121 pub fn to_ascii_with_options(
122 &self,
123 title: Option<&str>,
124 width: usize,
125 k: Option<usize>,
126 color_options: &ColorOptions,
127 ) -> String {
128 let (features, importance) = if let Some(top_k) = k {
129 self.top_k(top_k)
130 } else {
131 (self.feature_names.clone(), self.importance.clone())
132 };
133
134 // Pre-allocate result string with estimated capacity
135 let mut result = String::with_capacity(features.len() * 80);
136
137 // Add title if provided
138 if let Some(titletext) = title {
139 if color_options.enabled {
140 result.push_str(&stylize(titletext, Style::Bold));
141 } else {
142 result.push_str(titletext);
143 }
144 result.push_str("\n\n");
145 }
146
147 // Determine maximum importance for scaling
148 let max_importance =
149 importance
150 .iter()
151 .copied()
152 .fold(F::zero(), |acc, v| if v > acc { v } else { acc });
153
154 // Get maximum feature name length for alignment
155 let max_name_len = features
156 .iter()
157 .map(|name| name.len())
158 .max()
159 .unwrap_or(10)
160 .max(10);
161
162 // Determine available width for bars
163 let bar_area_width = width.saturating_sub(max_name_len + 10);
164
165 // Constants for formatting
166 const ANSI_PADDING: usize = 9; // Extra space needed for ANSI color codes
167
168 // Create a sorted index
169 let mut indices: Vec<usize> = (0..features.len()).collect();
170 indices.sort_by(|&a, &b| {
171 importance[b]
172 .partial_cmp(&importance[a])
173 .unwrap_or(std::cmp::Ordering::Equal)
174 });
175
176 // Draw the bar chart
177 for &idx in &indices {
178 let feature_name = &features[idx];
179 let imp = importance[idx];
180
181 // Scale the importance to fit in the available width
182 let bar_length = if max_importance > F::zero() {
183 let ratio = (imp / max_importance).to_f64().unwrap_or(0.0);
184 (ratio * bar_area_width as f64).round() as usize
185 } else {
186 0
187 };
188
189 // Format feature name with styling
190 let formatted_name = if color_options.enabled {
191 stylize(feature_name, Style::Bold).to_string()
192 } else {
193 feature_name.clone()
194 };
195
196 // Format importance value with coloring
197 let normalized_imp = if max_importance > F::zero() {
198 (imp / max_importance).to_f64().unwrap_or(0.0)
199 } else {
200 0.0
201 };
202
203 let formatted_imp = if color_options.enabled {
204 colored_metric_cell(format!("{imp:.3}"), normalized_imp, color_options)
205 } else {
206 format!("{imp:.3}")
207 };
208
209 // Format the bar with coloring
210 let bar = if color_options.enabled {
211 if let Some(color) = gradient_color(normalized_imp, color_options) {
212 colorize("█".repeat(bar_length), color)
213 } else {
214 "█".repeat(bar_length)
215 }
216 } else {
217 "█".repeat(bar_length)
218 };
219
220 // Format the line with consistent width calculations
221 let name_padding = if color_options.enabled {
222 ANSI_PADDING
223 } else {
224 0
225 };
226
227 result.push_str(&format!(
228 "{:<width$} | {} | {}\n",
229 formatted_name,
230 formatted_imp,
231 bar,
232 width = max_name_len + name_padding
233 ));
234 }
235
236 result
237 }
238}