sklears_inspection/visualization/
plotting_functions.rs

//! 2D Plotting Functions Module
//!
//! This module provides comprehensive 2D plotting functionality for model interpretability,
//! consolidating feature importance plots, SHAP visualizations, partial dependence plots,
//! and comparative analysis visualizations in a single, convenient interface.
//!
//! ## Key Features
//!
//! - **Feature Importance Plots**: Interactive bar, horizontal, radial, and tree-map visualizations
//! - **SHAP Visualizations**: Waterfall, force layout, summary, dependence, and beeswarm plots
//! - **Partial Dependence Plots**: PDP curves with optional ICE (Individual Conditional Expectation) curves
//! - **Comparative Plots**: Side-by-side model comparisons and overlay visualizations
//! - **High Performance**: Leverages SciRS2 for optimized numerical computations
//! - **Comprehensive Validation**: Input validation and detailed error messages
//!
//! ## Usage Examples
//!
//! ```rust,ignore
//! use sklears_inspection::visualization::plotting_functions::*;
//! // ✅ SciRS2 Policy Compliant Import
//! use scirs2_core::ndarray::array;
//!
//! // Feature importance plot
//! let importance = array![0.3, 0.5, 0.2];
//! let features = vec!["Feature1".to_string(), "Feature2".to_string(), "Feature3".to_string()];
//! let config = PlotConfig::default();
//!
//! let plot = create_feature_importance_plot(
//!     &importance.view(),
//!     Some(&features),
//!     None,
//!     &config,
//!     FeatureImportanceType::Bar
//! ).unwrap();
//!
//! // SHAP plot
//! let shap_values = array![[0.1, 0.2, -0.1], [0.3, -0.1, 0.2]];
//! let feature_values = array![[1.0, 2.0, 3.0], [1.5, 2.5, 3.5]];
//!
//! let shap_plot = create_shap_visualization(
//!     &shap_values.view(),
//!     &feature_values.view(),
//!     None,
//!     None,
//!     &config,
//!     ShapPlotType::Summary,
//! ).unwrap();
//! ```
49
50use crate::{Float, SklResult};
51// ✅ SciRS2 Policy Compliant Import
52use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
53use std::collections::HashMap;
54
55use super::config_types::{
56    ComparativePlot, ComparisonType, FeatureImportancePlot, FeatureImportanceType,
57    PartialDependencePlot, PlotConfig, ShapPlot, ShapPlotType,
58};
59
60// =============================================================================
61// Feature Importance Plotting Functions
62// =============================================================================
63
64/// Create interactive feature importance plot
65///
66/// Generates interactive visualizations of feature importance scores with support
67/// for error bars, multiple plot types, and comprehensive validation.
68///
69/// # Arguments
70///
71/// * `importance_values` - Feature importance scores as 1D array
72/// * `feature_names` - Optional feature names; generates default names if None
73/// * `std_values` - Optional standard deviations for error bars
74/// * `config` - Plot configuration settings
75/// * `plot_type` - Type of feature importance visualization
76///
77/// # Returns
78///
79/// Result containing feature importance plot data or error
80///
81/// # Errors
82///
83/// - `InvalidInput` - If feature names length doesn't match importance values
84/// - `InvalidInput` - If std_values length doesn't match importance values
85///
86/// # Examples
87///
88/// ```rust,ignore
89/// use sklears_inspection::visualization::plotting_functions::*;
90/// // ✅ SciRS2 Policy Compliant Import
91/// use scirs2_core::ndarray::array;
92///
93/// let importance = array![0.3, 0.5, 0.2];
94/// let features = vec!["Feature1".to_string(), "Feature2".to_string(), "Feature3".to_string()];
95/// let config = PlotConfig::default();
96///
97/// let plot = create_feature_importance_plot(
98///     &importance.view(),
99///     Some(&features),
100///     None,
101///     &config,
102///     FeatureImportanceType::Bar
103/// ).unwrap();
104///
105/// assert_eq!(plot.feature_names.len(), 3);
106/// assert_eq!(plot.importance_values.len(), 3);
107/// assert_eq!(plot.importance_values[1], 0.5);
108/// ```
109pub fn create_feature_importance_plot(
110    importance_values: &ArrayView1<Float>,
111    feature_names: Option<&[String]>,
112    std_values: Option<&ArrayView1<Float>>,
113    config: &PlotConfig,
114    plot_type: FeatureImportanceType,
115) -> SklResult<FeatureImportancePlot> {
116    let n_features = importance_values.len();
117
118    // Validate input dimensions
119    if n_features == 0 {
120        return Err(crate::SklearsError::InvalidInput(
121            "Importance values cannot be empty".to_string(),
122        ));
123    }
124
125    // Generate or validate feature names
126    let feature_names = if let Some(names) = feature_names {
127        if names.len() != n_features {
128            return Err(crate::SklearsError::InvalidInput(format!(
129                "Feature names length ({}) does not match importance values length ({})",
130                names.len(),
131                n_features
132            )));
133        }
134        names.to_vec()
135    } else {
136        (0..n_features).map(|i| format!("Feature_{}", i)).collect()
137    };
138
139    // Validate standard deviations if provided
140    if let Some(std) = std_values {
141        if std.len() != n_features {
142            return Err(crate::SklearsError::InvalidInput(format!(
143                "Standard deviation values length ({}) does not match importance values length ({})",
144                std.len(),
145                n_features
146            )));
147        }
148
149        // Check for negative standard deviations
150        for (i, &val) in std.iter().enumerate() {
151            if val < 0.0 {
152                return Err(crate::SklearsError::InvalidInput(format!(
153                    "Standard deviation at index {} is negative: {}",
154                    i, val
155                )));
156            }
157        }
158    }
159
160    let std_values = std_values.map(|std| std.to_vec());
161    let importance_values = importance_values.to_vec();
162
163    Ok(FeatureImportancePlot {
164        feature_names,
165        importance_values,
166        std_values,
167        config: config.clone(),
168        plot_type,
169    })
170}
171
172/// Create advanced feature importance plot with ranking and filtering
173///
174/// Enhanced version that provides additional functionality such as automatic ranking,
175/// filtering by threshold, and statistical significance testing.
176///
177/// # Arguments
178///
179/// * `importance_values` - Feature importance scores
180/// * `feature_names` - Optional feature names
181/// * `std_values` - Optional standard deviations
182/// * `config` - Plot configuration
183/// * `plot_type` - Type of visualization
184/// * `top_k` - Optional limit to top K features (None for all features)
185/// * `min_threshold` - Optional minimum importance threshold for inclusion
186///
187/// # Returns
188///
189/// Result containing filtered and ranked feature importance plot data
190pub fn create_ranked_feature_importance_plot(
191    importance_values: &ArrayView1<Float>,
192    feature_names: Option<&[String]>,
193    std_values: Option<&ArrayView1<Float>>,
194    config: &PlotConfig,
195    plot_type: FeatureImportanceType,
196    top_k: Option<usize>,
197    min_threshold: Option<Float>,
198) -> SklResult<FeatureImportancePlot> {
199    let n_features = importance_values.len();
200
201    if n_features == 0 {
202        return Err(crate::SklearsError::InvalidInput(
203            "Importance values cannot be empty".to_string(),
204        ));
205    }
206
207    // Create feature names if not provided
208    let feature_names = if let Some(names) = feature_names {
209        if names.len() != n_features {
210            return Err(crate::SklearsError::InvalidInput(format!(
211                "Feature names length ({}) does not match importance values length ({})",
212                names.len(),
213                n_features
214            )));
215        }
216        names.to_vec()
217    } else {
218        (0..n_features).map(|i| format!("Feature_{}", i)).collect()
219    };
220
221    // Create indices and sort by importance (descending)
222    let mut indices: Vec<usize> = (0..n_features).collect();
223    indices.sort_by(|&a, &b| {
224        importance_values[b]
225            .partial_cmp(&importance_values[a])
226            .unwrap_or(std::cmp::Ordering::Equal)
227    });
228
229    // Apply threshold filtering
230    if let Some(threshold) = min_threshold {
231        indices.retain(|&i| importance_values[i].abs() >= threshold);
232    }
233
234    // Apply top-k filtering
235    if let Some(k) = top_k {
236        indices.truncate(k);
237    }
238
239    if indices.is_empty() {
240        return Err(crate::SklearsError::InvalidInput(
241            "No features meet the filtering criteria".to_string(),
242        ));
243    }
244
245    // Extract filtered and sorted data
246    let filtered_names = indices.iter().map(|&i| feature_names[i].clone()).collect();
247    let filtered_importance = indices.iter().map(|&i| importance_values[i]).collect();
248    let filtered_std = std_values.map(|std| indices.iter().map(|&i| std[i]).collect());
249
250    Ok(FeatureImportancePlot {
251        feature_names: filtered_names,
252        importance_values: filtered_importance,
253        std_values: filtered_std,
254        config: config.clone(),
255        plot_type,
256    })
257}
258
259// =============================================================================
260// SHAP Plotting Functions
261// =============================================================================
262
263/// Create interactive SHAP plot for model explanations
264///
265/// Generates SHAP (SHapley Additive exPlanations) visualizations for understanding
266/// model predictions with comprehensive validation and multiple plot types.
267///
268/// # Arguments
269///
270/// * `shap_values` - SHAP values matrix (instances × features)
271/// * `feature_values` - Feature values matrix (instances × features)
272/// * `feature_names` - Optional feature names; generates defaults if None
273/// * `instance_names` - Optional instance names; generates defaults if None
274/// * `config` - Plot configuration settings
275/// * `plot_type` - Type of SHAP visualization
276///
277/// # Returns
278///
279/// Result containing SHAP plot data or error
280///
281/// # Errors
282///
283/// - `InvalidInput` - If SHAP and feature values dimensions don't match
284/// - `InvalidInput` - If feature/instance names lengths don't match data dimensions
285///
286/// # Examples
287///
288/// ```rust,ignore
289/// use sklears_inspection::visualization::plotting_functions::*;
290/// // ✅ SciRS2 Policy Compliant Import
291/// use scirs2_core::ndarray::array;
292///
293/// let shap_values = array![[0.1, 0.2, -0.1], [0.3, -0.1, 0.2]];
294/// let feature_values = array![[1.0, 2.0, 3.0], [1.5, 2.5, 3.5]];
295/// let config = PlotConfig::default();
296///
297/// let plot = create_shap_visualization(
298///     &shap_values.view(),
299///     &feature_values.view(),
300///     None,
301///     None,
302///     &config,
303///     ShapPlotType::Summary,
304/// ).unwrap();
305///
306/// assert_eq!(plot.shap_values.shape(), &[2, 3]);
307/// assert_eq!(plot.feature_names.len(), 3);
308/// assert_eq!(plot.instance_names.len(), 2);
309/// ```
310pub fn create_shap_visualization(
311    shap_values: &ArrayView2<Float>,
312    feature_values: &ArrayView2<Float>,
313    feature_names: Option<&[String]>,
314    instance_names: Option<&[String]>,
315    config: &PlotConfig,
316    plot_type: ShapPlotType,
317) -> SklResult<ShapPlot> {
318    let (n_instances, n_features) = shap_values.dim();
319
320    // Validate input dimensions
321    if n_instances == 0 || n_features == 0 {
322        return Err(crate::SklearsError::InvalidInput(
323            "SHAP values cannot have zero dimensions".to_string(),
324        ));
325    }
326
327    if feature_values.dim() != (n_instances, n_features) {
328        return Err(crate::SklearsError::InvalidInput(format!(
329            "SHAP values shape {:?} and feature values shape {:?} do not match",
330            (n_instances, n_features),
331            feature_values.dim()
332        )));
333    }
334
335    // Generate or validate feature names
336    let feature_names = if let Some(names) = feature_names {
337        if names.len() != n_features {
338            return Err(crate::SklearsError::InvalidInput(format!(
339                "Feature names length ({}) does not match number of features ({})",
340                names.len(),
341                n_features
342            )));
343        }
344        names.to_vec()
345    } else {
346        (0..n_features).map(|i| format!("Feature_{}", i)).collect()
347    };
348
349    // Generate or validate instance names
350    let instance_names = if let Some(names) = instance_names {
351        if names.len() != n_instances {
352            return Err(crate::SklearsError::InvalidInput(format!(
353                "Instance names length ({}) does not match number of instances ({})",
354                names.len(),
355                n_instances
356            )));
357        }
358        names.to_vec()
359    } else {
360        (0..n_instances)
361            .map(|i| format!("Instance_{}", i))
362            .collect()
363    };
364
365    Ok(ShapPlot {
366        shap_values: shap_values.to_owned(),
367        feature_values: feature_values.to_owned(),
368        feature_names,
369        instance_names,
370        config: config.clone(),
371        plot_type,
372    })
373}
374
375/// Create SHAP summary plot with aggregated feature importance
376///
377/// Specialized function for creating SHAP summary plots that aggregate
378/// importance across all instances and provide statistical summaries.
379///
380/// # Arguments
381///
382/// * `shap_values` - SHAP values matrix (instances × features)
383/// * `feature_values` - Feature values matrix (instances × features)
384/// * `feature_names` - Optional feature names
385/// * `config` - Plot configuration
386/// * `show_distribution` - Whether to include value distribution information
387///
388/// # Returns
389///
390/// Result containing aggregated SHAP summary plot
391pub fn create_shap_summary_plot(
392    shap_values: &ArrayView2<Float>,
393    feature_values: &ArrayView2<Float>,
394    feature_names: Option<&[String]>,
395    config: &PlotConfig,
396    show_distribution: bool,
397) -> SklResult<ShapPlot> {
398    let (n_instances, n_features) = shap_values.dim();
399
400    if n_instances == 0 || n_features == 0 {
401        return Err(crate::SklearsError::InvalidInput(
402            "SHAP values cannot have zero dimensions".to_string(),
403        ));
404    }
405
406    if feature_values.dim() != (n_instances, n_features) {
407        return Err(crate::SklearsError::InvalidInput(
408            "SHAP values and feature values dimensions do not match".to_string(),
409        ));
410    }
411
412    let feature_names = if let Some(names) = feature_names {
413        if names.len() != n_features {
414            return Err(crate::SklearsError::InvalidInput(
415                "Feature names length does not match number of features".to_string(),
416            ));
417        }
418        names.to_vec()
419    } else {
420        (0..n_features).map(|i| format!("Feature_{}", i)).collect()
421    };
422
423    let instance_names = (0..n_instances)
424        .map(|i| format!("Instance_{}", i))
425        .collect();
426
427    let plot_type = if show_distribution {
428        ShapPlotType::Beeswarm
429    } else {
430        ShapPlotType::Summary
431    };
432
433    Ok(ShapPlot {
434        shap_values: shap_values.to_owned(),
435        feature_values: feature_values.to_owned(),
436        feature_names,
437        instance_names,
438        config: config.clone(),
439        plot_type,
440    })
441}
442
443// =============================================================================
444// Partial Dependence Plotting Functions
445// =============================================================================
446
447/// Create partial dependence plot (PDP) with optional ICE curves
448///
449/// Generates partial dependence plots showing how model predictions change
450/// with feature values, with optional Individual Conditional Expectation curves.
451///
452/// # Arguments
453///
454/// * `feature_values` - Feature values for x-axis (sorted grid points)
455/// * `pd_values` - Partial dependence values corresponding to feature values
456/// * `ice_curves` - Optional ICE curves (instances × feature_values)
457/// * `feature_name` - Name of the feature being analyzed
458/// * `config` - Plot configuration settings
459/// * `show_ice` - Whether to display individual ICE curves
460///
461/// # Returns
462///
463/// Result containing partial dependence plot data or error
464///
465/// # Errors
466///
467/// - `InvalidInput` - If feature values and PD values lengths don't match
468/// - `InvalidInput` - If ICE curves columns don't match feature values length
469///
470/// # Examples
471///
472/// ```rust,ignore
473/// use sklears_inspection::visualization::plotting_functions::*;
474/// // ✅ SciRS2 Policy Compliant Import
475/// use scirs2_core::ndarray::array;
476///
477/// let feature_values = array![0.0, 0.2, 0.4, 0.6, 0.8, 1.0];
478/// let pd_values = array![0.1, 0.3, 0.5, 0.4, 0.2, 0.1];
479/// let config = PlotConfig::default();
480///
481/// let plot = create_partial_dependence_plot(
482///     &feature_values.view(),
483///     &pd_values.view(),
484///     None,
485///     "feature_1",
486///     &config,
487///     false,
488/// ).unwrap();
489///
490/// assert_eq!(plot.feature_name, "feature_1");
491/// assert_eq!(plot.feature_values.len(), 6);
492/// assert_eq!(plot.pd_values.len(), 6);
493/// assert!(!plot.show_ice);
494/// ```
495pub fn create_partial_dependence_plot(
496    feature_values: &ArrayView1<Float>,
497    pd_values: &ArrayView1<Float>,
498    ice_curves: Option<&ArrayView2<Float>>,
499    feature_name: &str,
500    config: &PlotConfig,
501    show_ice: bool,
502) -> SklResult<PartialDependencePlot> {
503    let n_points = feature_values.len();
504
505    // Validate input dimensions
506    if n_points == 0 {
507        return Err(crate::SklearsError::InvalidInput(
508            "Feature values cannot be empty".to_string(),
509        ));
510    }
511
512    if pd_values.len() != n_points {
513        return Err(crate::SklearsError::InvalidInput(format!(
514            "Feature values length ({}) and PD values length ({}) must match",
515            n_points,
516            pd_values.len()
517        )));
518    }
519
520    // Validate ICE curves if provided
521    if let Some(ice) = ice_curves {
522        if ice.ncols() != n_points {
523            return Err(crate::SklearsError::InvalidInput(format!(
524                "ICE curves columns ({}) must match feature values length ({})",
525                ice.ncols(),
526                n_points
527            )));
528        }
529
530        if ice.nrows() == 0 {
531            return Err(crate::SklearsError::InvalidInput(
532                "ICE curves cannot have zero instances".to_string(),
533            ));
534        }
535    }
536
537    // Validate feature values are sorted (for proper PD interpretation)
538    for i in 1..n_points {
539        if feature_values[i] < feature_values[i - 1] {
540            return Err(crate::SklearsError::InvalidInput(
541                "Feature values must be sorted in ascending order for proper PD interpretation"
542                    .to_string(),
543            ));
544        }
545    }
546
547    Ok(PartialDependencePlot {
548        feature_values: feature_values.to_owned(),
549        pd_values: pd_values.to_owned(),
550        ice_curves: ice_curves.map(|ice| ice.to_owned()),
551        feature_name: feature_name.to_string(),
552        config: config.clone(),
553        show_ice: show_ice && ice_curves.is_some(),
554    })
555}
556
557/// Create 2D partial dependence plot for feature interaction analysis
558///
559/// Creates a 2D PDP showing how two features interact to affect model predictions.
560/// Returns data suitable for contour plots, heatmaps, or 3D surface visualization.
561///
562/// # Arguments
563///
564/// * `feature1_values` - Values for first feature (x-axis)
565/// * `feature2_values` - Values for second feature (y-axis)
566/// * `pd_surface` - 2D partial dependence values (feature1 × feature2)
567/// * `feature1_name` - Name of first feature
568/// * `feature2_name` - Name of second feature
569/// * `config` - Plot configuration
570///
571/// # Returns
572///
573/// Result containing 2D partial dependence data structured for visualization
574pub fn create_2d_partial_dependence_plot(
575    feature1_values: &ArrayView1<Float>,
576    feature2_values: &ArrayView1<Float>,
577    pd_surface: &ArrayView2<Float>,
578    feature1_name: &str,
579    feature2_name: &str,
580    config: &PlotConfig,
581) -> SklResult<ComparativePlot> {
582    let n_points1 = feature1_values.len();
583    let n_points2 = feature2_values.len();
584
585    if n_points1 == 0 || n_points2 == 0 {
586        return Err(crate::SklearsError::InvalidInput(
587            "Feature values cannot be empty".to_string(),
588        ));
589    }
590
591    if pd_surface.dim() != (n_points1, n_points2) {
592        return Err(crate::SklearsError::InvalidInput(format!(
593            "PD surface shape {:?} does not match expected shape ({}, {})",
594            pd_surface.dim(),
595            n_points1,
596            n_points2
597        )));
598    }
599
600    let mut model_data = HashMap::new();
601    model_data.insert("2D_PD_Surface".to_string(), pd_surface.to_owned());
602
603    let labels = vec![feature1_name.to_string(), feature2_name.to_string()];
604
605    Ok(ComparativePlot {
606        model_data,
607        labels,
608        config: config.clone(),
609        comparison_type: ComparisonType::Heatmap,
610    })
611}
612
613// =============================================================================
614// Comparative Plotting Functions
615// =============================================================================
616
617/// Create comparative visualization for model comparison
618///
619/// Generates comparative plots for analyzing differences between multiple models
620/// or different parameter configurations with comprehensive validation.
621///
622/// # Arguments
623///
624/// * `model_data` - HashMap of model names to their prediction/score data
625/// * `labels` - Labels for data dimensions/features
626/// * `config` - Plot configuration settings
627/// * `comparison_type` - Type of comparison visualization
628///
629/// # Returns
630///
631/// Result containing comparative plot data or error
632///
633/// # Errors
634///
635/// - `InvalidInput` - If model data is empty
636/// - `InvalidInput` - If model data shapes are inconsistent
637///
638/// # Examples
639///
640/// ```rust,ignore
641/// use sklears_inspection::visualization::plotting_functions::*;
642/// use std::collections::HashMap;
643/// // ✅ SciRS2 Policy Compliant Import
644/// use scirs2_core::ndarray::array;
645///
646/// let mut model_data = HashMap::new();
647/// model_data.insert("model_1".to_string(), array![[1.0, 2.0], [3.0, 4.0]]);
648/// model_data.insert("model_2".to_string(), array![[2.0, 3.0], [4.0, 5.0]]);
649///
650/// let labels = vec!["Feature A".to_string(), "Feature B".to_string()];
651/// let config = PlotConfig::default();
652///
653/// let plot = create_comparative_plot(
654///     model_data,
655///     labels,
656///     &config,
657///     ComparisonType::SideBySide
658/// ).unwrap();
659///
660/// assert_eq!(plot.model_data.len(), 2);
661/// assert_eq!(plot.labels.len(), 2);
662/// ```
663pub fn create_comparative_plot(
664    model_data: HashMap<String, Array2<Float>>,
665    labels: Vec<String>,
666    config: &PlotConfig,
667    comparison_type: ComparisonType,
668) -> SklResult<ComparativePlot> {
669    // Validate non-empty model data
670    if model_data.is_empty() {
671        return Err(crate::SklearsError::InvalidInput(
672            "Model data cannot be empty".to_string(),
673        ));
674    }
675
676    // Validate labels are not empty
677    if labels.is_empty() {
678        return Err(crate::SklearsError::InvalidInput(
679            "Labels cannot be empty".to_string(),
680        ));
681    }
682
683    // Validate that all model data has compatible dimensions
684    let first_entry = model_data.iter().next().unwrap();
685    let (first_name, first_data) = first_entry;
686    let expected_shape = first_data.dim();
687
688    // Check for empty data arrays
689    if expected_shape.0 == 0 || expected_shape.1 == 0 {
690        return Err(crate::SklearsError::InvalidInput(format!(
691            "Model '{}' has invalid data shape: {:?}",
692            first_name, expected_shape
693        )));
694    }
695
696    // Validate all models have consistent shapes
697    for (model_name, data) in &model_data {
698        let current_shape = data.dim();
699        if current_shape != expected_shape {
700            return Err(crate::SklearsError::InvalidInput(format!(
701                "Model '{}' data shape {:?} does not match expected shape {:?}",
702                model_name, current_shape, expected_shape
703            )));
704        }
705
706        // Validate for non-finite values
707        for value in data.iter() {
708            if !value.is_finite() {
709                return Err(crate::SklearsError::InvalidInput(format!(
710                    "Model '{}' contains non-finite values",
711                    model_name
712                )));
713            }
714        }
715    }
716
717    // Validate labels count matches data dimensions
718    if labels.len() != expected_shape.1 {
719        return Err(crate::SklearsError::InvalidInput(format!(
720            "Labels count ({}) does not match data columns ({})",
721            labels.len(),
722            expected_shape.1
723        )));
724    }
725
726    Ok(ComparativePlot {
727        model_data,
728        labels,
729        config: config.clone(),
730        comparison_type,
731    })
732}
733
734/// Create performance comparison plot for multiple metrics
735///
736/// Specialized comparative plot for model performance metrics with statistical
737/// significance testing and confidence intervals.
738///
739/// # Arguments
740///
741/// * `performance_data` - Performance metrics for each model
742/// * `metric_names` - Names of the performance metrics
743/// * `confidence_intervals` - Optional confidence intervals for each metric
744/// * `config` - Plot configuration
745/// * `show_significance` - Whether to show statistical significance markers
746///
747/// # Returns
748///
749/// Result containing performance comparison plot data
750pub fn create_performance_comparison_plot(
751    performance_data: HashMap<String, Array1<Float>>,
752    metric_names: Vec<String>,
753    confidence_intervals: Option<HashMap<String, Array2<Float>>>,
754    config: &PlotConfig,
755    show_significance: bool,
756) -> SklResult<ComparativePlot> {
757    if performance_data.is_empty() {
758        return Err(crate::SklearsError::InvalidInput(
759            "Performance data cannot be empty".to_string(),
760        ));
761    }
762
763    if metric_names.is_empty() {
764        return Err(crate::SklearsError::InvalidInput(
765            "Metric names cannot be empty".to_string(),
766        ));
767    }
768
769    // Convert 1D performance data to 2D for compatibility with ComparativePlot
770    let mut model_data_2d = HashMap::new();
771    let expected_len = metric_names.len();
772
773    for (model_name, metrics) in performance_data {
774        if metrics.len() != expected_len {
775            return Err(crate::SklearsError::InvalidInput(format!(
776                "Model '{}' metrics length ({}) does not match expected length ({})",
777                model_name,
778                metrics.len(),
779                expected_len
780            )));
781        }
782
783        // Convert to 2D array (1 × n_metrics)
784        let metrics_2d = metrics.insert_axis(scirs2_core::ndarray::Axis(0));
785        model_data_2d.insert(model_name, metrics_2d);
786    }
787
788    // Validate confidence intervals if provided
789    if let Some(ci) = &confidence_intervals {
790        for (model_name, intervals) in ci {
791            if !model_data_2d.contains_key(model_name) {
792                return Err(crate::SklearsError::InvalidInput(format!(
793                    "Confidence interval provided for unknown model: '{}'",
794                    model_name
795                )));
796            }
797
798            if intervals.dim() != (2, expected_len) {
799                return Err(crate::SklearsError::InvalidInput(format!(
800                    "Confidence intervals for model '{}' have incorrect shape: {:?}, expected (2, {})",
801                    model_name,
802                    intervals.dim(),
803                    expected_len
804                )));
805            }
806        }
807    }
808
809    let comparison_type = if show_significance {
810        ComparisonType::Statistical
811    } else {
812        ComparisonType::SideBySide
813    };
814
815    Ok(ComparativePlot {
816        model_data: model_data_2d,
817        labels: metric_names,
818        config: config.clone(),
819        comparison_type,
820    })
821}
822
823// =============================================================================
824// Comprehensive Tests
825// =============================================================================
826
827#[cfg(test)]
828mod tests {
829    use super::*;
830    // ✅ SciRS2 Policy Compliant Import
831    use scirs2_core::ndarray::array;
832
833    // Feature Importance Tests
834    #[test]
835    fn test_feature_importance_plot_creation() {
836        let importance = array![0.3, 0.5, 0.2];
837        let features = vec![
838            "Feature1".to_string(),
839            "Feature2".to_string(),
840            "Feature3".to_string(),
841        ];
842        let config = PlotConfig::default();
843
844        let plot = create_feature_importance_plot(
845            &importance.view(),
846            Some(&features),
847            None,
848            &config,
849            FeatureImportanceType::Bar,
850        )
851        .unwrap();
852
853        assert_eq!(plot.feature_names.len(), 3);
854        assert_eq!(plot.importance_values.len(), 3);
855        assert_eq!(plot.importance_values[1], 0.5);
856        assert!(plot.std_values.is_none());
857        assert_eq!(plot.plot_type, FeatureImportanceType::Bar);
858    }
859
860    #[test]
861    fn test_feature_importance_with_std() {
862        let importance = array![0.3, 0.5, 0.2];
863        let std_vals = array![0.1, 0.05, 0.15];
864        let config = PlotConfig::default();
865
866        let plot = create_feature_importance_plot(
867            &importance.view(),
868            None,
869            Some(&std_vals.view()),
870            &config,
871            FeatureImportanceType::Horizontal,
872        )
873        .unwrap();
874
875        assert_eq!(plot.feature_names.len(), 3);
876        assert!(plot.std_values.is_some());
877        assert_eq!(plot.std_values.as_ref().unwrap().len(), 3);
878        assert_eq!(plot.plot_type, FeatureImportanceType::Horizontal);
879    }
880
/// Two importance values paired with three feature names must be rejected.
#[test]
fn test_feature_importance_dimension_mismatch() {
    let cfg = PlotConfig::default();
    // Only two scores ...
    let scores = array![0.3, 0.5];
    // ... but three names.
    let names: Vec<String> = ["Feature1", "Feature2", "Feature3"]
        .iter()
        .map(|&name| name.to_string())
        .collect();
    let scores_view = scores.view();

    let outcome = create_feature_importance_plot(
        &scores_view,
        Some(&names),
        None,
        &cfg,
        FeatureImportanceType::Bar,
    );

    assert!(outcome.is_err());
}
900
/// An empty importance vector must produce an error rather than a plot.
#[test]
fn test_feature_importance_empty_input() {
    let cfg = PlotConfig::default();
    let no_scores = array![];

    assert!(create_feature_importance_plot(
        &no_scores.view(),
        None,
        None,
        &cfg,
        FeatureImportanceType::Bar,
    )
    .is_err());
}
915
/// A negative entry among the standard deviations must be rejected.
#[test]
fn test_feature_importance_negative_std() {
    let cfg = PlotConfig::default();
    let scores = array![0.3, 0.5, 0.2];
    // The middle entry is deliberately negative.
    let spreads = array![0.1, -0.05, 0.15];
    let spreads_view = spreads.view();

    let outcome = create_feature_importance_plot(
        &scores.view(),
        None,
        Some(&spreads_view),
        &cfg,
        FeatureImportanceType::Bar,
    );

    assert!(outcome.is_err());
}
931
/// Ranks four importance values with a top-2 limit and a 0.15 threshold and
/// checks that exactly the two largest values (0.5, then 0.3) remain, in
/// descending order.
#[test]
fn test_ranked_feature_importance_plot() {
    let scores = array![0.1, 0.5, 0.3, 0.2];
    let cfg = PlotConfig::default();
    let top_k = Some(2); // keep at most two features
    let min_importance = Some(0.15); // drop anything below this
    let scores_view = scores.view();

    let plot = create_ranked_feature_importance_plot(
        &scores_view,
        None,
        None,
        &cfg,
        FeatureImportanceType::Bar,
        top_k,
        min_importance,
    )
    .unwrap();

    assert_eq!(plot.feature_names.len(), 2);
    assert_eq!(plot.importance_values.len(), 2);

    // Highest value first, second highest next.
    for (rank, expected) in [0.5, 0.3].iter().enumerate() {
        assert_eq!(plot.importance_values[rank], *expected);
    }
}
954
955    // SHAP Tests
/// Creates a summary SHAP visualization from a 2x3 value matrix and checks the
/// stored shape, the number of feature and instance names, and the plot type.
#[test]
fn test_shap_plot_creation() {
    let cfg = PlotConfig::default();
    let contributions = array![[0.1, 0.2, -0.1], [0.3, -0.1, 0.2]];
    let inputs = array![[1.0, 2.0, 3.0], [1.5, 2.5, 3.5]];
    let (contrib_view, input_view) = (contributions.view(), inputs.view());

    let plot = create_shap_visualization(
        &contrib_view,
        &input_view,
        None,
        None,
        &cfg,
        ShapPlotType::Summary,
    )
    .unwrap();

    let (n_instances, n_features) = (2, 3);
    assert_eq!(plot.shap_values.shape(), &[n_instances, n_features]);
    assert_eq!(plot.feature_names.len(), n_features);
    assert_eq!(plot.instance_names.len(), n_instances);
    assert_eq!(plot.plot_type, ShapPlotType::Summary);
}
977
/// A 2x2 SHAP matrix combined with a 2x3 feature matrix must be rejected.
#[test]
fn test_shap_plot_dimension_mismatch() {
    let cfg = PlotConfig::default();
    let contributions = array![[0.1, 0.2], [0.3, -0.1]];
    let inputs = array![[1.0, 2.0, 3.0], [1.5, 2.5, 3.5]];
    let contrib_view = contributions.view();
    let input_view = inputs.view();

    let outcome = create_shap_visualization(
        &contrib_view,
        &input_view,
        None,
        None,
        &cfg,
        ShapPlotType::Summary,
    );

    assert!(outcome.is_err());
}
994
/// Matrices with two rows but zero columns must be rejected.
#[test]
fn test_shap_plot_zero_dimensions() {
    let cfg = PlotConfig::default();
    // Both inputs are 2x0 arrays.
    let empty_contributions = array![[], []];
    let empty_inputs = array![[], []];

    assert!(create_shap_visualization(
        &empty_contributions.view(),
        &empty_inputs.view(),
        None,
        None,
        &cfg,
        ShapPlotType::Summary,
    )
    .is_err());
}
1011
/// Calls the summary helper with the distribution flag enabled and checks that
/// the 3x3 value matrix is preserved and the plot type is reported as beeswarm.
#[test]
fn test_shap_summary_plot() {
    let cfg = PlotConfig::default();
    let contributions = array![[0.1, 0.2, -0.1], [0.3, -0.1, 0.2], [0.0, 0.1, -0.05]];
    let inputs = array![[1.0, 2.0, 3.0], [1.5, 2.5, 3.5], [1.2, 2.1, 3.2]];
    let show_distribution = true;
    let (contrib_view, input_view) = (contributions.view(), inputs.view());

    let summary = create_shap_summary_plot(
        &contrib_view,
        &input_view,
        None,
        &cfg,
        show_distribution,
    );
    let plot = summary.unwrap();

    assert_eq!(plot.plot_type, ShapPlotType::Beeswarm);
    assert_eq!(plot.shap_values.shape(), &[3, 3]);
}
1030
1031    // Partial Dependence Tests
/// Builds a partial dependence plot from six grid points without ICE curves and
/// checks the stored feature name, the lengths of both series, and that ICE
/// display is off with no curves attached.
#[test]
fn test_partial_dependence_plot_creation() {
    let grid = array![0.0, 0.2, 0.4, 0.6, 0.8, 1.0];
    let responses = array![0.1, 0.3, 0.5, 0.4, 0.2, 0.1];
    let cfg = PlotConfig::default();
    let (grid_view, response_view) = (grid.view(), responses.view());

    let plot = create_partial_dependence_plot(
        &grid_view,
        &response_view,
        None,
        "feature_1",
        &cfg,
        false,
    )
    .unwrap();

    assert_eq!(plot.feature_name, "feature_1");
    assert_eq!(plot.feature_values.len(), 6);
    assert_eq!(plot.pd_values.len(), 6);
    assert!(plot.ice_curves.is_none());
    assert!(!plot.show_ice);
}
1054
/// Builds a partial dependence plot with ICE curves for two instances over
/// three grid points and checks that the curves are kept with shape 2x3 and
/// that ICE display is enabled.
#[test]
fn test_partial_dependence_plot_with_ice() {
    let cfg = PlotConfig::default();
    let grid = array![0.0, 0.5, 1.0];
    let responses = array![0.1, 0.5, 0.2];
    // One row per instance, one column per grid point.
    let per_instance = array![[0.0, 0.4, 0.1], [0.2, 0.6, 0.3]];
    let (grid_view, response_view) = (grid.view(), responses.view());
    let ice_view = per_instance.view();

    let plot = create_partial_dependence_plot(
        &grid_view,
        &response_view,
        Some(&ice_view),
        "feature_1",
        &cfg,
        true,
    )
    .unwrap();

    assert_eq!(plot.feature_name, "feature_1");
    assert!(plot.show_ice);

    let curves = plot.ice_curves.as_ref();
    assert!(curves.is_some());
    assert_eq!(curves.unwrap().shape(), &[2, 3]);
}
1077
/// Three grid points paired with only two partial dependence values must be
/// rejected.
#[test]
fn test_partial_dependence_plot_dimension_mismatch() {
    let cfg = PlotConfig::default();
    let grid = array![0.0, 0.5, 1.0];
    // One value short of the grid length.
    let responses = array![0.1, 0.5];
    let grid_view = grid.view();
    let response_view = responses.view();

    let outcome = create_partial_dependence_plot(
        &grid_view,
        &response_view,
        None,
        "feature_1",
        &cfg,
        false,
    );

    assert!(outcome.is_err());
}
1094
/// Feature grid values that are not in ascending order must be rejected.
#[test]
fn test_partial_dependence_plot_unsorted_features() {
    let cfg = PlotConfig::default();
    // 1.0 appears before 0.5, so the grid is out of order.
    let grid = array![0.0, 1.0, 0.5];
    let responses = array![0.1, 0.2, 0.3];
    let (grid_view, response_view) = (grid.view(), responses.view());

    assert!(create_partial_dependence_plot(
        &grid_view,
        &response_view,
        None,
        "feature_1",
        &cfg,
        false,
    )
    .is_err());
}
1111
/// Builds a two-feature partial dependence plot from a 3x2 surface and checks
/// that the result holds a single "2D_PD_Surface" entry, two labels, and the
/// heatmap comparison type.
#[test]
fn test_2d_partial_dependence_plot() {
    let cfg = PlotConfig::default();
    let first_axis = array![0.0, 0.5, 1.0];
    let second_axis = array![0.0, 1.0];
    // Rows follow the first axis (3), columns the second (2).
    let surface = array![[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]];
    let (first_view, second_view) = (first_axis.view(), second_axis.view());
    let surface_view = surface.view();

    let built = create_2d_partial_dependence_plot(
        &first_view,
        &second_view,
        &surface_view,
        "feature_1",
        "feature_2",
        &cfg,
    );
    let plot = built.unwrap();

    assert_eq!(plot.comparison_type, ComparisonType::Heatmap);
    assert_eq!(plot.labels.len(), 2);
    assert_eq!(plot.model_data.len(), 1);
    assert!(plot.model_data.contains_key("2D_PD_Surface"));
}
1134
1135    // Comparative Plot Tests
/// Builds a side-by-side comparison of two models with identically shaped data
/// and checks the number of models, the number of labels, and the comparison
/// type on the result.
#[test]
fn test_comparative_plot_creation() {
    let model_data: HashMap<_, _> = vec![
        ("model_1".to_string(), array![[1.0, 2.0], [3.0, 4.0]]),
        ("model_2".to_string(), array![[2.0, 3.0], [4.0, 5.0]]),
    ]
    .into_iter()
    .collect();
    let labels: Vec<String> = ["Feature A", "Feature B"]
        .iter()
        .map(|&label| label.to_string())
        .collect();
    let cfg = PlotConfig::default();

    let built = create_comparative_plot(model_data, labels, &cfg, ComparisonType::SideBySide);
    let plot = built.unwrap();

    assert_eq!(plot.comparison_type, ComparisonType::SideBySide);
    assert_eq!(plot.model_data.len(), 2);
    assert_eq!(plot.labels.len(), 2);
}
1152
/// A comparison request that contains no model data must be rejected.
#[test]
fn test_comparative_plot_empty_data() {
    let cfg = PlotConfig::default();
    let labels = vec![String::from("Feature A")];
    let no_models = HashMap::new();

    assert!(
        create_comparative_plot(no_models, labels, &cfg, ComparisonType::SideBySide).is_err()
    );
}
1163
/// Models whose data arrays have different shapes (2x2 versus 1x3) must be
/// rejected.
#[test]
fn test_comparative_plot_shape_mismatch() {
    let cfg = PlotConfig::default();
    let model_data: HashMap<_, _> = vec![
        // 2x2
        ("model_1".to_string(), array![[1.0, 2.0], [3.0, 4.0]]),
        // 1x3, deliberately a different shape
        ("model_2".to_string(), array![[2.0, 3.0, 5.0]]),
    ]
    .into_iter()
    .collect();
    let labels: Vec<String> = ["Feature A", "Feature B"]
        .iter()
        .map(|&label| label.to_string())
        .collect();

    let outcome = create_comparative_plot(model_data, labels, &cfg, ComparisonType::SideBySide);

    assert!(outcome.is_err());
}
1177
/// Compares two models across three metrics with significance display turned
/// off and checks that the result is a side-by-side comparison holding both
/// models and all three metric labels.
#[test]
fn test_performance_comparison_plot() {
    let cfg = PlotConfig::default();
    let performance_data: HashMap<_, _> = vec![
        ("model_1".to_string(), array![0.85, 0.78, 0.92]),
        ("model_2".to_string(), array![0.83, 0.80, 0.89]),
    ]
    .into_iter()
    .collect();
    let metric_names: Vec<String> = ["Accuracy", "Precision", "Recall"]
        .iter()
        .map(|&metric| metric.to_string())
        .collect();
    let show_significance = false;

    let plot = create_performance_comparison_plot(
        performance_data,
        metric_names,
        None,
        &cfg,
        show_significance,
    )
    .unwrap();

    assert_eq!(plot.comparison_type, ComparisonType::SideBySide);
    assert_eq!(plot.model_data.len(), 2);
    assert_eq!(plot.labels.len(), 3);
}
1204
/// With significance display turned on, the performance comparison must report
/// the statistical comparison type.
#[test]
fn test_performance_comparison_with_significance() {
    let cfg = PlotConfig::default();
    let performance_data: HashMap<_, _> =
        vec![("model_1".to_string(), array![0.85, 0.78])]
            .into_iter()
            .collect();
    let metric_names: Vec<String> = ["Accuracy", "Precision"]
        .iter()
        .map(|&metric| metric.to_string())
        .collect();
    let show_significance = true;

    let built = create_performance_comparison_plot(
        performance_data,
        metric_names,
        None,
        &cfg,
        show_significance,
    );
    let plot = built.unwrap();

    assert_eq!(plot.comparison_type, ComparisonType::Statistical);
}
1224
1225    // Edge Case Tests
/// Smoke-tests every plot-type variant listed below against minimal
/// single-element inputs and requires each construction to succeed.
///
/// Each assertion names the variant under test and prints the returned error,
/// so a failure inside either loop points directly at the offending variant
/// instead of only reporting that `result.is_ok()` was false.
#[test]
fn test_all_plot_types_enum_coverage() {
    let importance = array![0.5];
    let config = PlotConfig::default();

    // Feature importance variants.
    for &plot_type in &[
        FeatureImportanceType::Bar,
        FeatureImportanceType::Horizontal,
        FeatureImportanceType::Radial,
        FeatureImportanceType::TreeMap,
    ] {
        let result =
            create_feature_importance_plot(&importance.view(), None, None, &config, plot_type);
        assert!(
            result.is_ok(),
            "feature importance plot failed for {:?}: {:?}",
            plot_type,
            result.as_ref().err()
        );
    }

    let shap_values = array![[0.1]];
    let feature_values = array![[1.0]];

    // SHAP visualization variants.
    for &plot_type in &[
        ShapPlotType::Waterfall,
        ShapPlotType::ForceLayout,
        ShapPlotType::Summary,
        ShapPlotType::Dependence,
        ShapPlotType::Beeswarm,
        ShapPlotType::DecisionPlot,
    ] {
        let result = create_shap_visualization(
            &shap_values.view(),
            &feature_values.view(),
            None,
            None,
            &config,
            plot_type,
        );
        assert!(
            result.is_ok(),
            "SHAP visualization failed for {:?}: {:?}",
            plot_type,
            result.as_ref().err()
        );
    }
}
1265}