scirs2_neural/utils/evaluation/
feature_importance.rs

1//! Feature importance visualization for machine learning models
2//!
3//! This module provides tools for visualizing and analyzing feature importance
4//! metrics from machine learning models, helping to understand which features
5//! have the greatest impact on model predictions.
6
7use crate::error::{NeuralError, Result};
8use crate::utils::colors::{
9    colored_metric_cell, colorize, gradient_color, stylize, ColorOptions, Style,
10};
11use scirs2_core::ndarray::Array1;
12use scirs2_core::numeric::Float;
13use std::fmt::{Debug, Display};
14
15/// Feature importance visualization for machine learning models
16///
17/// This struct facilitates the visualization and analysis of feature importance
18/// scores from machine learning models, helping to identify which features
19/// contribute most to predictions.
20pub struct FeatureImportance<F: Float + Debug + Display> {
21    /// Names of the features
22    pub feature_names: Vec<String>,
23    /// Importance scores for each feature
24    pub importance: Array1<F>,
25}
26
27impl<F: Float + Debug + Display> FeatureImportance<F> {
28    /// Create a new feature importance visualization
29    ///
30    /// # Arguments
31    /// * `feature_names` - Names of features
32    /// * `importance` - Importance scores
33    ///
34    /// # Returns
35    /// * `Result<FeatureImportance<F>>` - The feature importance visualization
36    ///
37    /// # Example
38    /// ```
39    /// use scirs2_core::ndarray::Array1;
40    /// use scirs2_neural::utils::evaluation::FeatureImportance;
41    ///
42    /// // Create feature names and importance scores
43    /// let feature_names = vec![
44    ///     "Age".to_string(),
45    ///     "Income".to_string(),
46    ///     "Education".to_string(),
47    ///     "Location".to_string()
48    /// ];
49    /// let importance = Array1::from_vec(vec![0.35, 0.25, 0.20, 0.10]);
50    ///
51    /// // Create feature importance visualization
52    /// let feature_importance = FeatureImportance::<f64>::new(feature_names, importance).unwrap();
53    /// ```
54    pub fn new(_featurenames: Vec<String>, importance: Array1<F>) -> Result<Self> {
55        if _featurenames.len() != importance.len() {
56            return Err(NeuralError::ValidationError(
57                "Number of feature _names must match number of importance scores".to_string(),
58            ));
59        }
60
61        Ok(FeatureImportance {
62            feature_names: _featurenames,
63            importance,
64        })
65    }
66
67    /// Get the top-k most important features
68    ///
69    /// # Arguments
70    /// * `k` - Number of top features to return
71    ///
72    /// # Returns
73    /// * `(Vec<String>, Array1<F>)` - Feature names and importance scores
74    pub fn top_k(&self, k: usize) -> (Vec<String>, Array1<F>) {
75        let k = std::cmp::min(k, self.feature_names.len());
76        let mut indices: Vec<usize> = (0..self.feature_names.len()).collect();
77
78        // Sort indices by importance (descending)
79        indices.sort_by(|&a, &b| {
80            self.importance[b]
81                .partial_cmp(&self.importance[a])
82                .unwrap_or(std::cmp::Ordering::Equal)
83        });
84
85        // Select top-k features
86        let top_indices = indices[..k].to_vec();
87        let top_names = top_indices
88            .iter()
89            .map(|&i| self.feature_names[i].clone())
90            .collect();
91        let top_importance = Array1::from_iter(top_indices.iter().map(|&i| self.importance[i]));
92
93        (top_names, top_importance)
94    }
95
96    /// Create an ASCII bar chart of feature importance
97    ///
98    /// # Arguments
99    /// * `title` - Optional title for the chart
100    /// * `width` - Width of the bar chart
101    /// * `k` - Number of top features to include (None for all)
102    ///
103    /// # Returns
104    /// * `String` - ASCII bar chart
105    pub fn to_ascii(&self, title: Option<&str>, width: usize, k: Option<usize>) -> String {
106        self.to_ascii_with_options(title, width, k, &ColorOptions::default())
107    }
108
109    /// Create an ASCII bar chart of feature importance with color options
110    /// This method creates a bar chart visualization with customizable colors,
111    /// showing feature importance scores in descending order.
112    ///
113    /// # Arguments
114    /// * `title` - Optional title for the chart
115    /// * `width` - Width of the bar chart
116    /// * `k` - Number of top features to include (None for all)
117    /// * `color_options` - Color options for visualization
118    ///
119    /// # Returns
120    /// * `String` - ASCII bar chart with colors
121    pub fn to_ascii_with_options(
122        &self,
123        title: Option<&str>,
124        width: usize,
125        k: Option<usize>,
126        color_options: &ColorOptions,
127    ) -> String {
128        let (features, importance) = if let Some(top_k) = k {
129            self.top_k(top_k)
130        } else {
131            (self.feature_names.clone(), self.importance.clone())
132        };
133
134        // Pre-allocate result string with estimated capacity
135        let mut result = String::with_capacity(features.len() * 80);
136
137        // Add title if provided
138        if let Some(titletext) = title {
139            if color_options.enabled {
140                result.push_str(&stylize(titletext, Style::Bold));
141            } else {
142                result.push_str(titletext);
143            }
144            result.push_str("\n\n");
145        }
146
147        // Determine maximum importance for scaling
148        let max_importance =
149            importance
150                .iter()
151                .copied()
152                .fold(F::zero(), |acc, v| if v > acc { v } else { acc });
153
154        // Get maximum feature name length for alignment
155        let max_name_len = features
156            .iter()
157            .map(|name| name.len())
158            .max()
159            .unwrap_or(10)
160            .max(10);
161
162        // Determine available width for bars
163        let bar_area_width = width.saturating_sub(max_name_len + 10);
164
165        // Constants for formatting
166        const ANSI_PADDING: usize = 9; // Extra space needed for ANSI color codes
167
168        // Create a sorted index
169        let mut indices: Vec<usize> = (0..features.len()).collect();
170        indices.sort_by(|&a, &b| {
171            importance[b]
172                .partial_cmp(&importance[a])
173                .unwrap_or(std::cmp::Ordering::Equal)
174        });
175
176        // Draw the bar chart
177        for &idx in &indices {
178            let feature_name = &features[idx];
179            let imp = importance[idx];
180
181            // Scale the importance to fit in the available width
182            let bar_length = if max_importance > F::zero() {
183                let ratio = (imp / max_importance).to_f64().unwrap_or(0.0);
184                (ratio * bar_area_width as f64).round() as usize
185            } else {
186                0
187            };
188
189            // Format feature name with styling
190            let formatted_name = if color_options.enabled {
191                stylize(feature_name, Style::Bold).to_string()
192            } else {
193                feature_name.clone()
194            };
195
196            // Format importance value with coloring
197            let normalized_imp = if max_importance > F::zero() {
198                (imp / max_importance).to_f64().unwrap_or(0.0)
199            } else {
200                0.0
201            };
202
203            let formatted_imp = if color_options.enabled {
204                colored_metric_cell(format!("{imp:.3}"), normalized_imp, color_options)
205            } else {
206                format!("{imp:.3}")
207            };
208
209            // Format the bar with coloring
210            let bar = if color_options.enabled {
211                if let Some(color) = gradient_color(normalized_imp, color_options) {
212                    colorize("█".repeat(bar_length), color)
213                } else {
214                    "█".repeat(bar_length)
215                }
216            } else {
217                "█".repeat(bar_length)
218            };
219
220            // Format the line with consistent width calculations
221            let name_padding = if color_options.enabled {
222                ANSI_PADDING
223            } else {
224                0
225            };
226
227            result.push_str(&format!(
228                "{:<width$} | {} | {}\n",
229                formatted_name,
230                formatted_imp,
231                bar,
232                width = max_name_len + name_padding
233            ));
234        }
235
236        result
237    }
238}