1use crate::{Float, SklResult};
51use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
53use std::collections::HashMap;
54
55use super::config_types::{
56 ComparativePlot, ComparisonType, FeatureImportancePlot, FeatureImportanceType,
57 PartialDependencePlot, PlotConfig, ShapPlot, ShapPlotType,
58};
59
60pub fn create_feature_importance_plot(
110 importance_values: &ArrayView1<Float>,
111 feature_names: Option<&[String]>,
112 std_values: Option<&ArrayView1<Float>>,
113 config: &PlotConfig,
114 plot_type: FeatureImportanceType,
115) -> SklResult<FeatureImportancePlot> {
116 let n_features = importance_values.len();
117
118 if n_features == 0 {
120 return Err(crate::SklearsError::InvalidInput(
121 "Importance values cannot be empty".to_string(),
122 ));
123 }
124
125 let feature_names = if let Some(names) = feature_names {
127 if names.len() != n_features {
128 return Err(crate::SklearsError::InvalidInput(format!(
129 "Feature names length ({}) does not match importance values length ({})",
130 names.len(),
131 n_features
132 )));
133 }
134 names.to_vec()
135 } else {
136 (0..n_features).map(|i| format!("Feature_{}", i)).collect()
137 };
138
139 if let Some(std) = std_values {
141 if std.len() != n_features {
142 return Err(crate::SklearsError::InvalidInput(format!(
143 "Standard deviation values length ({}) does not match importance values length ({})",
144 std.len(),
145 n_features
146 )));
147 }
148
149 for (i, &val) in std.iter().enumerate() {
151 if val < 0.0 {
152 return Err(crate::SklearsError::InvalidInput(format!(
153 "Standard deviation at index {} is negative: {}",
154 i, val
155 )));
156 }
157 }
158 }
159
160 let std_values = std_values.map(|std| std.to_vec());
161 let importance_values = importance_values.to_vec();
162
163 Ok(FeatureImportancePlot {
164 feature_names,
165 importance_values,
166 std_values,
167 config: config.clone(),
168 plot_type,
169 })
170}
171
172pub fn create_ranked_feature_importance_plot(
191 importance_values: &ArrayView1<Float>,
192 feature_names: Option<&[String]>,
193 std_values: Option<&ArrayView1<Float>>,
194 config: &PlotConfig,
195 plot_type: FeatureImportanceType,
196 top_k: Option<usize>,
197 min_threshold: Option<Float>,
198) -> SklResult<FeatureImportancePlot> {
199 let n_features = importance_values.len();
200
201 if n_features == 0 {
202 return Err(crate::SklearsError::InvalidInput(
203 "Importance values cannot be empty".to_string(),
204 ));
205 }
206
207 let feature_names = if let Some(names) = feature_names {
209 if names.len() != n_features {
210 return Err(crate::SklearsError::InvalidInput(format!(
211 "Feature names length ({}) does not match importance values length ({})",
212 names.len(),
213 n_features
214 )));
215 }
216 names.to_vec()
217 } else {
218 (0..n_features).map(|i| format!("Feature_{}", i)).collect()
219 };
220
221 let mut indices: Vec<usize> = (0..n_features).collect();
223 indices.sort_by(|&a, &b| {
224 importance_values[b]
225 .partial_cmp(&importance_values[a])
226 .unwrap_or(std::cmp::Ordering::Equal)
227 });
228
229 if let Some(threshold) = min_threshold {
231 indices.retain(|&i| importance_values[i].abs() >= threshold);
232 }
233
234 if let Some(k) = top_k {
236 indices.truncate(k);
237 }
238
239 if indices.is_empty() {
240 return Err(crate::SklearsError::InvalidInput(
241 "No features meet the filtering criteria".to_string(),
242 ));
243 }
244
245 let filtered_names = indices.iter().map(|&i| feature_names[i].clone()).collect();
247 let filtered_importance = indices.iter().map(|&i| importance_values[i]).collect();
248 let filtered_std = std_values.map(|std| indices.iter().map(|&i| std[i]).collect());
249
250 Ok(FeatureImportancePlot {
251 feature_names: filtered_names,
252 importance_values: filtered_importance,
253 std_values: filtered_std,
254 config: config.clone(),
255 plot_type,
256 })
257}
258
259pub fn create_shap_visualization(
311 shap_values: &ArrayView2<Float>,
312 feature_values: &ArrayView2<Float>,
313 feature_names: Option<&[String]>,
314 instance_names: Option<&[String]>,
315 config: &PlotConfig,
316 plot_type: ShapPlotType,
317) -> SklResult<ShapPlot> {
318 let (n_instances, n_features) = shap_values.dim();
319
320 if n_instances == 0 || n_features == 0 {
322 return Err(crate::SklearsError::InvalidInput(
323 "SHAP values cannot have zero dimensions".to_string(),
324 ));
325 }
326
327 if feature_values.dim() != (n_instances, n_features) {
328 return Err(crate::SklearsError::InvalidInput(format!(
329 "SHAP values shape {:?} and feature values shape {:?} do not match",
330 (n_instances, n_features),
331 feature_values.dim()
332 )));
333 }
334
335 let feature_names = if let Some(names) = feature_names {
337 if names.len() != n_features {
338 return Err(crate::SklearsError::InvalidInput(format!(
339 "Feature names length ({}) does not match number of features ({})",
340 names.len(),
341 n_features
342 )));
343 }
344 names.to_vec()
345 } else {
346 (0..n_features).map(|i| format!("Feature_{}", i)).collect()
347 };
348
349 let instance_names = if let Some(names) = instance_names {
351 if names.len() != n_instances {
352 return Err(crate::SklearsError::InvalidInput(format!(
353 "Instance names length ({}) does not match number of instances ({})",
354 names.len(),
355 n_instances
356 )));
357 }
358 names.to_vec()
359 } else {
360 (0..n_instances)
361 .map(|i| format!("Instance_{}", i))
362 .collect()
363 };
364
365 Ok(ShapPlot {
366 shap_values: shap_values.to_owned(),
367 feature_values: feature_values.to_owned(),
368 feature_names,
369 instance_names,
370 config: config.clone(),
371 plot_type,
372 })
373}
374
375pub fn create_shap_summary_plot(
392 shap_values: &ArrayView2<Float>,
393 feature_values: &ArrayView2<Float>,
394 feature_names: Option<&[String]>,
395 config: &PlotConfig,
396 show_distribution: bool,
397) -> SklResult<ShapPlot> {
398 let (n_instances, n_features) = shap_values.dim();
399
400 if n_instances == 0 || n_features == 0 {
401 return Err(crate::SklearsError::InvalidInput(
402 "SHAP values cannot have zero dimensions".to_string(),
403 ));
404 }
405
406 if feature_values.dim() != (n_instances, n_features) {
407 return Err(crate::SklearsError::InvalidInput(
408 "SHAP values and feature values dimensions do not match".to_string(),
409 ));
410 }
411
412 let feature_names = if let Some(names) = feature_names {
413 if names.len() != n_features {
414 return Err(crate::SklearsError::InvalidInput(
415 "Feature names length does not match number of features".to_string(),
416 ));
417 }
418 names.to_vec()
419 } else {
420 (0..n_features).map(|i| format!("Feature_{}", i)).collect()
421 };
422
423 let instance_names = (0..n_instances)
424 .map(|i| format!("Instance_{}", i))
425 .collect();
426
427 let plot_type = if show_distribution {
428 ShapPlotType::Beeswarm
429 } else {
430 ShapPlotType::Summary
431 };
432
433 Ok(ShapPlot {
434 shap_values: shap_values.to_owned(),
435 feature_values: feature_values.to_owned(),
436 feature_names,
437 instance_names,
438 config: config.clone(),
439 plot_type,
440 })
441}
442
443pub fn create_partial_dependence_plot(
496 feature_values: &ArrayView1<Float>,
497 pd_values: &ArrayView1<Float>,
498 ice_curves: Option<&ArrayView2<Float>>,
499 feature_name: &str,
500 config: &PlotConfig,
501 show_ice: bool,
502) -> SklResult<PartialDependencePlot> {
503 let n_points = feature_values.len();
504
505 if n_points == 0 {
507 return Err(crate::SklearsError::InvalidInput(
508 "Feature values cannot be empty".to_string(),
509 ));
510 }
511
512 if pd_values.len() != n_points {
513 return Err(crate::SklearsError::InvalidInput(format!(
514 "Feature values length ({}) and PD values length ({}) must match",
515 n_points,
516 pd_values.len()
517 )));
518 }
519
520 if let Some(ice) = ice_curves {
522 if ice.ncols() != n_points {
523 return Err(crate::SklearsError::InvalidInput(format!(
524 "ICE curves columns ({}) must match feature values length ({})",
525 ice.ncols(),
526 n_points
527 )));
528 }
529
530 if ice.nrows() == 0 {
531 return Err(crate::SklearsError::InvalidInput(
532 "ICE curves cannot have zero instances".to_string(),
533 ));
534 }
535 }
536
537 for i in 1..n_points {
539 if feature_values[i] < feature_values[i - 1] {
540 return Err(crate::SklearsError::InvalidInput(
541 "Feature values must be sorted in ascending order for proper PD interpretation"
542 .to_string(),
543 ));
544 }
545 }
546
547 Ok(PartialDependencePlot {
548 feature_values: feature_values.to_owned(),
549 pd_values: pd_values.to_owned(),
550 ice_curves: ice_curves.map(|ice| ice.to_owned()),
551 feature_name: feature_name.to_string(),
552 config: config.clone(),
553 show_ice: show_ice && ice_curves.is_some(),
554 })
555}
556
557pub fn create_2d_partial_dependence_plot(
575 feature1_values: &ArrayView1<Float>,
576 feature2_values: &ArrayView1<Float>,
577 pd_surface: &ArrayView2<Float>,
578 feature1_name: &str,
579 feature2_name: &str,
580 config: &PlotConfig,
581) -> SklResult<ComparativePlot> {
582 let n_points1 = feature1_values.len();
583 let n_points2 = feature2_values.len();
584
585 if n_points1 == 0 || n_points2 == 0 {
586 return Err(crate::SklearsError::InvalidInput(
587 "Feature values cannot be empty".to_string(),
588 ));
589 }
590
591 if pd_surface.dim() != (n_points1, n_points2) {
592 return Err(crate::SklearsError::InvalidInput(format!(
593 "PD surface shape {:?} does not match expected shape ({}, {})",
594 pd_surface.dim(),
595 n_points1,
596 n_points2
597 )));
598 }
599
600 let mut model_data = HashMap::new();
601 model_data.insert("2D_PD_Surface".to_string(), pd_surface.to_owned());
602
603 let labels = vec![feature1_name.to_string(), feature2_name.to_string()];
604
605 Ok(ComparativePlot {
606 model_data,
607 labels,
608 config: config.clone(),
609 comparison_type: ComparisonType::Heatmap,
610 })
611}
612
613pub fn create_comparative_plot(
664 model_data: HashMap<String, Array2<Float>>,
665 labels: Vec<String>,
666 config: &PlotConfig,
667 comparison_type: ComparisonType,
668) -> SklResult<ComparativePlot> {
669 if model_data.is_empty() {
671 return Err(crate::SklearsError::InvalidInput(
672 "Model data cannot be empty".to_string(),
673 ));
674 }
675
676 if labels.is_empty() {
678 return Err(crate::SklearsError::InvalidInput(
679 "Labels cannot be empty".to_string(),
680 ));
681 }
682
683 let first_entry = model_data.iter().next().unwrap();
685 let (first_name, first_data) = first_entry;
686 let expected_shape = first_data.dim();
687
688 if expected_shape.0 == 0 || expected_shape.1 == 0 {
690 return Err(crate::SklearsError::InvalidInput(format!(
691 "Model '{}' has invalid data shape: {:?}",
692 first_name, expected_shape
693 )));
694 }
695
696 for (model_name, data) in &model_data {
698 let current_shape = data.dim();
699 if current_shape != expected_shape {
700 return Err(crate::SklearsError::InvalidInput(format!(
701 "Model '{}' data shape {:?} does not match expected shape {:?}",
702 model_name, current_shape, expected_shape
703 )));
704 }
705
706 for value in data.iter() {
708 if !value.is_finite() {
709 return Err(crate::SklearsError::InvalidInput(format!(
710 "Model '{}' contains non-finite values",
711 model_name
712 )));
713 }
714 }
715 }
716
717 if labels.len() != expected_shape.1 {
719 return Err(crate::SklearsError::InvalidInput(format!(
720 "Labels count ({}) does not match data columns ({})",
721 labels.len(),
722 expected_shape.1
723 )));
724 }
725
726 Ok(ComparativePlot {
727 model_data,
728 labels,
729 config: config.clone(),
730 comparison_type,
731 })
732}
733
734pub fn create_performance_comparison_plot(
751 performance_data: HashMap<String, Array1<Float>>,
752 metric_names: Vec<String>,
753 confidence_intervals: Option<HashMap<String, Array2<Float>>>,
754 config: &PlotConfig,
755 show_significance: bool,
756) -> SklResult<ComparativePlot> {
757 if performance_data.is_empty() {
758 return Err(crate::SklearsError::InvalidInput(
759 "Performance data cannot be empty".to_string(),
760 ));
761 }
762
763 if metric_names.is_empty() {
764 return Err(crate::SklearsError::InvalidInput(
765 "Metric names cannot be empty".to_string(),
766 ));
767 }
768
769 let mut model_data_2d = HashMap::new();
771 let expected_len = metric_names.len();
772
773 for (model_name, metrics) in performance_data {
774 if metrics.len() != expected_len {
775 return Err(crate::SklearsError::InvalidInput(format!(
776 "Model '{}' metrics length ({}) does not match expected length ({})",
777 model_name,
778 metrics.len(),
779 expected_len
780 )));
781 }
782
783 let metrics_2d = metrics.insert_axis(scirs2_core::ndarray::Axis(0));
785 model_data_2d.insert(model_name, metrics_2d);
786 }
787
788 if let Some(ci) = &confidence_intervals {
790 for (model_name, intervals) in ci {
791 if !model_data_2d.contains_key(model_name) {
792 return Err(crate::SklearsError::InvalidInput(format!(
793 "Confidence interval provided for unknown model: '{}'",
794 model_name
795 )));
796 }
797
798 if intervals.dim() != (2, expected_len) {
799 return Err(crate::SklearsError::InvalidInput(format!(
800 "Confidence intervals for model '{}' have incorrect shape: {:?}, expected (2, {})",
801 model_name,
802 intervals.dim(),
803 expected_len
804 )));
805 }
806 }
807 }
808
809 let comparison_type = if show_significance {
810 ComparisonType::Statistical
811 } else {
812 ComparisonType::SideBySide
813 };
814
815 Ok(ComparativePlot {
816 model_data: model_data_2d,
817 labels: metric_names,
818 config: config.clone(),
819 comparison_type,
820 })
821}
822
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    // Helpers that call the builders with a default `PlotConfig`, so each test
    // only spells out the inputs it cares about.

    /// Converts string literals into the owned names the builders accept.
    fn owned_names(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    /// `create_feature_importance_plot` with a default config.
    fn importance_plot(
        importance: &Array1<Float>,
        names: Option<&[String]>,
        std_devs: Option<&Array1<Float>>,
        plot_type: FeatureImportanceType,
    ) -> SklResult<FeatureImportancePlot> {
        let std_view = std_devs.map(|values| values.view());
        create_feature_importance_plot(
            &importance.view(),
            names,
            std_view.as_ref(),
            &PlotConfig::default(),
            plot_type,
        )
    }

    /// `create_shap_visualization` with default names and config.
    fn shap_plot(
        shap: &Array2<Float>,
        features: &Array2<Float>,
        plot_type: ShapPlotType,
    ) -> SklResult<ShapPlot> {
        create_shap_visualization(
            &shap.view(),
            &features.view(),
            None,
            None,
            &PlotConfig::default(),
            plot_type,
        )
    }

    /// `create_partial_dependence_plot` for a feature named `feature_1`.
    fn pd_plot(
        grid: &Array1<Float>,
        pd: &Array1<Float>,
        ice: Option<&Array2<Float>>,
        show_ice: bool,
    ) -> SklResult<PartialDependencePlot> {
        let ice_view = ice.map(|curves| curves.view());
        create_partial_dependence_plot(
            &grid.view(),
            &pd.view(),
            ice_view.as_ref(),
            "feature_1",
            &PlotConfig::default(),
            show_ice,
        )
    }

    // ---- feature importance -------------------------------------------

    #[test]
    fn test_feature_importance_plot_creation() {
        let names = owned_names(&["Feature1", "Feature2", "Feature3"]);
        let plot = importance_plot(
            &array![0.3, 0.5, 0.2],
            Some(names.as_slice()),
            None,
            FeatureImportanceType::Bar,
        )
        .expect("matching names should be accepted");

        assert_eq!(plot.feature_names.len(), 3);
        assert_eq!(plot.importance_values.len(), 3);
        assert_eq!(plot.importance_values[1], 0.5);
        assert!(plot.std_values.is_none());
        assert_eq!(plot.plot_type, FeatureImportanceType::Bar);
    }

    #[test]
    fn test_feature_importance_with_std() {
        let plot = importance_plot(
            &array![0.3, 0.5, 0.2],
            None,
            Some(&array![0.1, 0.05, 0.15]),
            FeatureImportanceType::Horizontal,
        )
        .expect("non-negative std values should be accepted");

        assert_eq!(plot.feature_names.len(), 3);
        let std_devs = plot
            .std_values
            .as_ref()
            .expect("std values should be carried through");
        assert_eq!(std_devs.len(), 3);
        assert_eq!(plot.plot_type, FeatureImportanceType::Horizontal);
    }

    #[test]
    fn test_feature_importance_dimension_mismatch() {
        // Three names for only two scores.
        let names = owned_names(&["Feature1", "Feature2", "Feature3"]);
        let outcome = importance_plot(
            &array![0.3, 0.5],
            Some(names.as_slice()),
            None,
            FeatureImportanceType::Bar,
        );
        assert!(outcome.is_err());
    }

    #[test]
    fn test_feature_importance_empty_input() {
        let empty = array![];
        assert!(importance_plot(&empty, None, None, FeatureImportanceType::Bar).is_err());
    }

    #[test]
    fn test_feature_importance_negative_std() {
        let outcome = importance_plot(
            &array![0.3, 0.5, 0.2],
            None,
            Some(&array![0.1, -0.05, 0.15]),
            FeatureImportanceType::Bar,
        );
        assert!(outcome.is_err(), "a negative standard deviation must be rejected");
    }

    #[test]
    fn test_ranked_feature_importance_plot() {
        let importance = array![0.1, 0.5, 0.3, 0.2];
        // 0.1 is below the threshold; of the remaining three only the top two stay.
        let plot = create_ranked_feature_importance_plot(
            &importance.view(),
            None,
            None,
            &PlotConfig::default(),
            FeatureImportanceType::Bar,
            Some(2),
            Some(0.15),
        )
        .expect("two features should survive the filters");

        assert_eq!(plot.feature_names.len(), 2);
        assert_eq!(plot.importance_values, vec![0.5, 0.3]);
    }

    // ---- SHAP -----------------------------------------------------------

    #[test]
    fn test_shap_plot_creation() {
        let plot = shap_plot(
            &array![[0.1, 0.2, -0.1], [0.3, -0.1, 0.2]],
            &array![[1.0, 2.0, 3.0], [1.5, 2.5, 3.5]],
            ShapPlotType::Summary,
        )
        .expect("matching shapes should be accepted");

        assert_eq!(plot.shap_values.shape(), &[2, 3]);
        assert_eq!(plot.feature_names.len(), 3);
        assert_eq!(plot.instance_names.len(), 2);
        assert_eq!(plot.plot_type, ShapPlotType::Summary);
    }

    #[test]
    fn test_shap_plot_dimension_mismatch() {
        let outcome = shap_plot(
            &array![[0.1, 0.2], [0.3, -0.1]],
            &array![[1.0, 2.0, 3.0], [1.5, 2.5, 3.5]],
            ShapPlotType::Summary,
        );
        assert!(outcome.is_err());
    }

    #[test]
    fn test_shap_plot_zero_dimensions() {
        // Two rows, zero columns on both sides.
        let empty_shap = array![[], []];
        let empty_features = array![[], []];
        assert!(shap_plot(&empty_shap, &empty_features, ShapPlotType::Summary).is_err());
    }

    #[test]
    fn test_shap_summary_plot() {
        let shap_values = array![[0.1, 0.2, -0.1], [0.3, -0.1, 0.2], [0.0, 0.1, -0.05]];
        let feature_values = array![[1.0, 2.0, 3.0], [1.5, 2.5, 3.5], [1.2, 2.1, 3.2]];

        let plot = create_shap_summary_plot(
            &shap_values.view(),
            &feature_values.view(),
            None,
            &PlotConfig::default(),
            true,
        )
        .expect("matching shapes should be accepted");

        assert_eq!(plot.shap_values.shape(), &[3, 3]);
        assert_eq!(plot.plot_type, ShapPlotType::Beeswarm);
    }

    // ---- partial dependence -----------------------------------------------

    #[test]
    fn test_partial_dependence_plot_creation() {
        let plot = pd_plot(
            &array![0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
            &array![0.1, 0.3, 0.5, 0.4, 0.2, 0.1],
            None,
            false,
        )
        .expect("a sorted grid should be accepted");

        assert_eq!(plot.feature_name, "feature_1");
        assert_eq!(plot.feature_values.len(), 6);
        assert_eq!(plot.pd_values.len(), 6);
        assert!(!plot.show_ice);
        assert!(plot.ice_curves.is_none());
    }

    #[test]
    fn test_partial_dependence_plot_with_ice() {
        // Two instances, one column per grid point.
        let ice = array![[0.0, 0.4, 0.1], [0.2, 0.6, 0.3]];
        let plot = pd_plot(&array![0.0, 0.5, 1.0], &array![0.1, 0.5, 0.2], Some(&ice), true)
            .expect("ICE curves with matching columns should be accepted");

        assert_eq!(plot.feature_name, "feature_1");
        assert!(plot.show_ice);
        let curves = plot
            .ice_curves
            .as_ref()
            .expect("ICE curves should be carried through");
        assert_eq!(curves.shape(), &[2, 3]);
    }

    #[test]
    fn test_partial_dependence_plot_dimension_mismatch() {
        let outcome = pd_plot(&array![0.0, 0.5, 1.0], &array![0.1, 0.5], None, false);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_partial_dependence_plot_unsorted_features() {
        let outcome = pd_plot(&array![0.0, 1.0, 0.5], &array![0.1, 0.2, 0.3], None, false);
        assert!(outcome.is_err(), "a descending grid step must be rejected");
    }

    #[test]
    fn test_2d_partial_dependence_plot() {
        let grid_a = array![0.0, 0.5, 1.0];
        let grid_b = array![0.0, 1.0];
        // Rows follow `grid_a`, columns follow `grid_b`.
        let surface = array![[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]];

        let plot = create_2d_partial_dependence_plot(
            &grid_a.view(),
            &grid_b.view(),
            &surface.view(),
            "feature_1",
            "feature_2",
            &PlotConfig::default(),
        )
        .expect("a correctly shaped surface should be accepted");

        assert_eq!(plot.model_data.len(), 1);
        assert!(plot.model_data.contains_key("2D_PD_Surface"));
        assert_eq!(plot.labels.len(), 2);
        assert_eq!(plot.comparison_type, ComparisonType::Heatmap);
    }

    // ---- comparisons ------------------------------------------------------

    #[test]
    fn test_comparative_plot_creation() {
        let model_data: HashMap<String, Array2<Float>> = vec![
            ("model_1".to_string(), array![[1.0, 2.0], [3.0, 4.0]]),
            ("model_2".to_string(), array![[2.0, 3.0], [4.0, 5.0]]),
        ]
        .into_iter()
        .collect();
        let labels = owned_names(&["Feature A", "Feature B"]);

        let plot = create_comparative_plot(
            model_data,
            labels,
            &PlotConfig::default(),
            ComparisonType::SideBySide,
        )
        .expect("same-shaped models should be accepted");

        assert_eq!(plot.model_data.len(), 2);
        assert_eq!(plot.labels.len(), 2);
        assert_eq!(plot.comparison_type, ComparisonType::SideBySide);
    }

    #[test]
    fn test_comparative_plot_empty_data() {
        let outcome = create_comparative_plot(
            HashMap::new(),
            owned_names(&["Feature A"]),
            &PlotConfig::default(),
            ComparisonType::SideBySide,
        );
        assert!(outcome.is_err());
    }

    #[test]
    fn test_comparative_plot_shape_mismatch() {
        let model_data: HashMap<String, Array2<Float>> = vec![
            ("model_1".to_string(), array![[1.0, 2.0], [3.0, 4.0]]),
            ("model_2".to_string(), array![[2.0, 3.0, 5.0]]),
        ]
        .into_iter()
        .collect();

        let outcome = create_comparative_plot(
            model_data,
            owned_names(&["Feature A", "Feature B"]),
            &PlotConfig::default(),
            ComparisonType::SideBySide,
        );
        assert!(outcome.is_err());
    }

    #[test]
    fn test_performance_comparison_plot() {
        let performance_data: HashMap<String, Array1<Float>> = vec![
            ("model_1".to_string(), array![0.85, 0.78, 0.92]),
            ("model_2".to_string(), array![0.83, 0.80, 0.89]),
        ]
        .into_iter()
        .collect();
        let metric_names = owned_names(&["Accuracy", "Precision", "Recall"]);

        let plot = create_performance_comparison_plot(
            performance_data,
            metric_names,
            None,
            &PlotConfig::default(),
            false,
        )
        .expect("consistent metric vectors should be accepted");

        assert_eq!(plot.model_data.len(), 2);
        assert_eq!(plot.labels.len(), 3);
        assert_eq!(plot.comparison_type, ComparisonType::SideBySide);
    }

    #[test]
    fn test_performance_comparison_with_significance() {
        let performance_data: HashMap<String, Array1<Float>> =
            vec![("model_1".to_string(), array![0.85, 0.78])]
                .into_iter()
                .collect();

        let plot = create_performance_comparison_plot(
            performance_data,
            owned_names(&["Accuracy", "Precision"]),
            None,
            &PlotConfig::default(),
            true,
        )
        .expect("consistent metric vectors should be accepted");

        assert_eq!(plot.comparison_type, ComparisonType::Statistical);
    }

    // ---- enum coverage ----------------------------------------------------

    #[test]
    fn test_all_plot_types_enum_coverage() {
        let single_importance = array![0.5];
        let importance_types = [
            FeatureImportanceType::Bar,
            FeatureImportanceType::Horizontal,
            FeatureImportanceType::Radial,
            FeatureImportanceType::TreeMap,
        ];
        for plot_type in importance_types.iter().copied() {
            assert!(importance_plot(&single_importance, None, None, plot_type).is_ok());
        }

        let single_shap = array![[0.1]];
        let single_feature = array![[1.0]];
        let shap_types = [
            ShapPlotType::Waterfall,
            ShapPlotType::ForceLayout,
            ShapPlotType::Summary,
            ShapPlotType::Dependence,
            ShapPlotType::Beeswarm,
            ShapPlotType::DecisionPlot,
        ];
        for plot_type in shap_types.iter().copied() {
            assert!(shap_plot(&single_shap, &single_feature, plot_type).is_ok());
        }
    }
}