1use crate::{Float, SklResult};
7use scirs2_core::ndarray::{Array1, Array2};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::fmt::Debug;
12use std::sync::Arc;
13
/// Pluggable rendering backend for explanation visualizations.
///
/// Each `render_*` method turns one kind of plot payload plus a
/// [`BackendConfig`] into a [`RenderedVisualization`]. Implementors must be
/// `Debug + Send + Sync`; [`BackendRegistry`] stores them as
/// `Arc<dyn VisualizationBackend>`.
pub trait VisualizationBackend: Debug + Send + Sync {
    /// Renders a feature-importance plot from `data`.
    fn render_feature_importance(
        &self,
        data: &FeatureImportanceData,
        config: &BackendConfig,
    ) -> SklResult<RenderedVisualization>;

    /// Renders a SHAP plot from `data`.
    fn render_shap_plot(
        &self,
        data: &ShapData,
        config: &BackendConfig,
    ) -> SklResult<RenderedVisualization>;

    /// Renders a partial-dependence plot from `data`.
    fn render_partial_dependence(
        &self,
        data: &PartialDependenceData,
        config: &BackendConfig,
    ) -> SklResult<RenderedVisualization>;

    /// Renders a multi-model comparison plot from `data`.
    fn render_comparative_plot(
        &self,
        data: &ComparativeData,
        config: &BackendConfig,
    ) -> SklResult<RenderedVisualization>;

    /// Renders a caller-defined plot from free-form `data`.
    fn render_custom_plot(
        &self,
        data: &CustomPlotData,
        config: &BackendConfig,
    ) -> SklResult<RenderedVisualization>;

    /// Returns the backend's name; `BackendRegistry::register_backend` uses
    /// it as the lookup key.
    fn name(&self) -> &str;

    /// Lists the output formats this backend reports it can produce.
    fn supported_formats(&self) -> Vec<OutputFormat>;

    /// Reports whether the backend claims to support interactive output.
    fn supports_interactivity(&self) -> bool;

    /// Returns the backend's full capability description.
    fn capabilities(&self) -> BackendCapabilities;
}
63
/// Rendering options passed to every [`VisualizationBackend`] call.
///
/// The backends defined in this file only read `width` and `height`; the
/// remaining fields are carried along for backends that choose to honor them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BackendConfig {
    /// Requested output format.
    pub format: OutputFormat,
    /// Plot width; the HTML backend forwards it unchanged as the Plotly
    /// layout `width`.
    pub width: usize,
    /// Plot height; the HTML backend forwards it unchanged as the Plotly
    /// layout `height`.
    pub height: usize,
    /// Output resolution (presumably dots per inch — not read in this file).
    pub dpi: usize,
    /// Whether interactive output is requested.
    pub interactive: bool,
    /// Requested color palette.
    pub color_scheme: ColorScheme,
    /// Requested visual theme.
    pub theme: Theme,
    /// Free-form string key/value options for backend-specific settings.
    pub custom_properties: HashMap<String, String>,
}
84
85impl Default for BackendConfig {
86 fn default() -> Self {
87 Self {
88 format: OutputFormat::Html,
89 width: 800,
90 height: 600,
91 dpi: 96,
92 interactive: true,
93 color_scheme: ColorScheme::Default,
94 theme: Theme::Light,
95 custom_properties: HashMap::new(),
96 }
97 }
98}
99
/// Output formats a backend can declare or produce.
///
/// Within this file only `Html`, `Json`, `Ascii` and `Unicode` are claimed by
/// a backend; the image and document variants have no producer here.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum OutputFormat {
    /// HTML page (produced by `HtmlBackend`).
    Html,
    /// JSON text (produced by `JsonBackend`).
    Json,
    /// SVG image.
    Svg,
    /// PNG image.
    Png,
    /// JPEG image.
    Jpeg,
    /// PDF document.
    Pdf,
    /// Plain-text chart (produced by `AsciiBackend`).
    Ascii,
    /// Text chart using Unicode glyphs (listed as supported by `AsciiBackend`).
    Unicode,
}
120
/// Named color palettes selectable through [`BackendConfig::color_scheme`].
///
/// No backend in this file reads this setting; the variant names presumably
/// follow the usual plotting-library palette names — confirm with the
/// backend that consumes them.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum ColorScheme {
    /// The value used by `BackendConfig::default()`.
    Default,
    Viridis,
    Plasma,
    Magma,
    Inferno,
    Blues,
    Reds,
    Greens,
    Categorical,
    Diverging,
}
145
/// Visual themes selectable through [`BackendConfig::theme`].
///
/// No backend in this file reads this setting.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum Theme {
    /// The value used by `BackendConfig::default()`.
    Light,
    Dark,
    HighContrast,
    Minimal,
    Scientific,
}
160
/// Self-reported feature set of a backend, as returned by
/// [`VisualizationBackend::capabilities`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BackendCapabilities {
    /// Output formats the backend declares.
    pub formats: Vec<OutputFormat>,
    /// Whether interactive output is available.
    pub interactive: bool,
    /// Whether animations are available.
    pub animations: bool,
    /// Whether 3-D plots are available.
    pub three_d: bool,
    /// Whether custom themes are available.
    pub custom_themes: bool,
    /// Whether real-time updates are available.
    pub real_time_updates: bool,
    /// Declared data-point limit; `None` presumably means "no limit".
    /// Nothing in this file enforces the value.
    pub max_data_points: Option<usize>,
}
179
/// Result of a render call: the produced document plus bookkeeping data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RenderedVisualization {
    /// Textual output (HTML, JSON or plain text for the backends in this file).
    pub content: String,
    /// Format of `content`.
    pub format: OutputFormat,
    /// Information about how and when the output was produced.
    pub metadata: VisualizationMetadata,
    /// Optional binary payload; every backend in this file sets it to `None`.
    pub binary_data: Option<Vec<u8>>,
}
192
/// Bookkeeping attached to every [`RenderedVisualization`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VisualizationMetadata {
    /// Name of the backend that produced the output.
    pub backend: String,
    /// Elapsed render time in milliseconds, as measured by the backend.
    pub render_time_ms: u64,
    /// Length of the rendered content in bytes.
    pub file_size_bytes: usize,
    /// Backend-reported count. The backends in this file use the element
    /// count of the primary data, the number of models for comparative
    /// plots, and `0` for custom plots.
    pub data_points: usize,
    /// UTC timestamp taken when the result was assembled.
    pub created_at: chrono::DateTime<chrono::Utc>,
}
207
/// Input for feature-importance plots.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeatureImportanceData {
    /// Feature labels; paired positionally with `importance_values`.
    pub feature_names: Vec<String>,
    /// Importance score per feature.
    pub importance_values: Vec<Float>,
    /// Optional per-feature spread (presumably standard deviations); not
    /// drawn by the HTML or ASCII backend in this file.
    pub std_values: Option<Vec<Float>>,
    /// Requested chart style.
    pub plot_type: FeatureImportanceType,
    /// Plot title.
    pub title: String,
    /// X-axis label (used by the HTML backend for `Bar` charts).
    pub x_label: String,
    /// Y-axis label (used by the HTML backend for `Bar` charts).
    pub y_label: String,
}
226
/// Chart styles for feature-importance plots.
///
/// The HTML backend in this file gives `Bar` its own layout and renders every
/// other variant through a plain bar-chart fallback.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum FeatureImportanceType {
    Bar,
    Horizontal,
    Radial,
    TreeMap,
    Waterfall,
}
241
/// Input for SHAP plots.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShapData {
    /// SHAP value matrix; the ASCII backend reports its rows as instances and
    /// its columns as features.
    pub shap_values: Array2<Float>,
    /// Feature value matrix (presumably the same shape as `shap_values` —
    /// nothing in this file checks it).
    pub feature_values: Array2<Float>,
    /// Feature labels.
    pub feature_names: Vec<String>,
    /// Instance labels.
    pub instance_names: Vec<String>,
    /// Requested SHAP chart style.
    pub plot_type: ShapPlotType,
    /// Plot title.
    pub title: String,
}
258
/// Chart styles for SHAP plots.
///
/// No backend in this file branches on this value.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum ShapPlotType {
    Waterfall,
    ForceLayout,
    Summary,
    Dependence,
    Beeswarm,
    DecisionPlot,
    Violin,
}
277
/// Input for partial-dependence plots.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PartialDependenceData {
    /// Feature grid values; plotted on the x-axis by the HTML backend.
    pub feature_values: Array1<Float>,
    /// Partial-dependence values; plotted on the y-axis by the HTML backend.
    pub pd_values: Array1<Float>,
    /// Optional ICE curves (layout not established in this file); not drawn
    /// by the backends here.
    pub ice_curves: Option<Array2<Float>>,
    /// Name of the feature being analyzed.
    pub feature_name: String,
    /// Plot title.
    pub title: String,
    /// Whether ICE curves should be shown; not read by the backends here.
    pub show_ice: bool,
}
294
/// Input for plots that compare several models.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComparativeData {
    /// One matrix per entry; the ASCII backend reports the entry count as the
    /// number of models, so keys are presumably model names.
    pub model_data: HashMap<String, Array2<Float>>,
    /// Labels accompanying the data (exact meaning not established here).
    pub labels: Vec<String>,
    /// Requested comparison style.
    pub comparison_type: ComparisonType,
    /// Plot title.
    pub title: String,
}
307
/// Comparison styles for [`ComparativeData`].
///
/// No backend in this file branches on this value.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum ComparisonType {
    SideBySide,
    Overlay,
    Difference,
    Ratio,
    Ranking,
}
322
/// Input for caller-defined plots.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CustomPlotData {
    /// Free-form JSON payload; the HTML backend emits it verbatim as the
    /// script's `data` variable.
    pub data: serde_json::Value,
    /// Caller-chosen plot type label (echoed by the ASCII backend).
    pub plot_type: String,
    /// Plot title.
    pub title: String,
    /// Additional string key/value annotations.
    pub metadata: HashMap<String, String>,
}
335
/// Name-keyed collection of rendering backends with an optional default.
#[derive(Debug, Default)]
pub struct BackendRegistry {
    /// Registered backends, keyed by the name each backend reports.
    backends: HashMap<String, Arc<dyn VisualizationBackend>>,
    /// Name of the backend returned by `get_default_backend`, if any.
    default_backend: Option<String>,
}
342
343impl BackendRegistry {
344 pub fn new() -> Self {
346 Self::default()
347 }
348
349 pub fn register_backend<B: VisualizationBackend + 'static>(&mut self, backend: B) {
351 let name = backend.name().to_string();
352 self.backends.insert(name.clone(), Arc::new(backend));
353
354 if self.default_backend.is_none() {
356 self.default_backend = Some(name);
357 }
358 }
359
360 pub fn get_backend(&self, name: &str) -> Option<Arc<dyn VisualizationBackend>> {
362 self.backends.get(name).cloned()
363 }
364
365 pub fn get_default_backend(&self) -> Option<Arc<dyn VisualizationBackend>> {
367 self.default_backend
368 .as_ref()
369 .and_then(|name| self.backends.get(name).cloned())
370 }
371
372 pub fn set_default_backend(&mut self, name: &str) -> SklResult<()> {
374 if self.backends.contains_key(name) {
375 self.default_backend = Some(name.to_string());
376 Ok(())
377 } else {
378 Err(crate::SklearsError::InvalidInput(format!(
379 "Backend '{}' not found",
380 name
381 )))
382 }
383 }
384
385 pub fn list_backends(&self) -> Vec<String> {
387 self.backends.keys().cloned().collect()
388 }
389
390 pub fn get_capabilities(&self, name: &str) -> Option<BackendCapabilities> {
392 self.backends.get(name).map(|b| b.capabilities())
393 }
394
395 pub fn find_backends_for_format(&self, format: OutputFormat) -> Vec<String> {
397 self.backends
398 .iter()
399 .filter(|(_, backend)| backend.supported_formats().contains(&format))
400 .map(|(name, _)| name.clone())
401 .collect()
402 }
403}
404
/// Front end that dispatches plot requests to the backends held in its
/// [`BackendRegistry`].
#[derive(Debug)]
pub struct VisualizationRenderer {
    /// Backends available to this renderer.
    registry: BackendRegistry,
}
410
411impl VisualizationRenderer {
412 pub fn new() -> Self {
414 Self {
415 registry: BackendRegistry::new(),
416 }
417 }
418
419 pub fn with_default_backends() -> Self {
421 let mut renderer = Self::new();
422 renderer.register_default_backends();
423 renderer
424 }
425
426 pub fn register_default_backends(&mut self) {
428 self.registry.register_backend(HtmlBackend::new());
429 self.registry.register_backend(JsonBackend::new());
430 self.registry.register_backend(AsciiBackend::new());
431 }
432
433 pub fn register_backend<B: VisualizationBackend + 'static>(&mut self, backend: B) {
435 self.registry.register_backend(backend);
436 }
437
438 pub fn render_with_backend(
440 &self,
441 backend_name: &str,
442 plot_type: PlotType,
443 config: &BackendConfig,
444 ) -> SklResult<RenderedVisualization> {
445 let backend = self.registry.get_backend(backend_name).ok_or_else(|| {
446 crate::SklearsError::InvalidInput(format!("Backend '{}' not found", backend_name))
447 })?;
448
449 match plot_type {
450 PlotType::FeatureImportance(data) => backend.render_feature_importance(&data, config),
451 PlotType::Shap(data) => backend.render_shap_plot(&data, config),
452 PlotType::PartialDependence(data) => backend.render_partial_dependence(&data, config),
453 PlotType::Comparative(data) => backend.render_comparative_plot(&data, config),
454 PlotType::Custom(data) => backend.render_custom_plot(&data, config),
455 }
456 }
457
458 pub fn render(
460 &self,
461 plot_type: PlotType,
462 config: &BackendConfig,
463 ) -> SklResult<RenderedVisualization> {
464 let backend = self.registry.get_default_backend().ok_or_else(|| {
465 crate::SklearsError::InvalidInput("No default backend available".to_string())
466 })?;
467
468 match plot_type {
469 PlotType::FeatureImportance(data) => backend.render_feature_importance(&data, config),
470 PlotType::Shap(data) => backend.render_shap_plot(&data, config),
471 PlotType::PartialDependence(data) => backend.render_partial_dependence(&data, config),
472 PlotType::Comparative(data) => backend.render_comparative_plot(&data, config),
473 PlotType::Custom(data) => backend.render_custom_plot(&data, config),
474 }
475 }
476
477 pub fn registry(&self) -> &BackendRegistry {
479 &self.registry
480 }
481
482 pub fn registry_mut(&mut self) -> &mut BackendRegistry {
484 &mut self.registry
485 }
486}
487
488impl Default for VisualizationRenderer {
489 fn default() -> Self {
490 Self::with_default_backends()
491 }
492}
493
/// Owned plot payload tagged by kind.
///
/// [`VisualizationRenderer`] matches on the variant to pick the backend
/// method that handles it.
#[derive(Debug, Clone)]
pub enum PlotType {
    /// Handled by `render_feature_importance`.
    FeatureImportance(FeatureImportanceData),
    /// Handled by `render_shap_plot`.
    Shap(ShapData),
    /// Handled by `render_partial_dependence`.
    PartialDependence(PartialDependenceData),
    /// Handled by `render_comparative_plot`.
    Comparative(ComparativeData),
    /// Handled by `render_custom_plot`.
    Custom(CustomPlotData),
}
508
/// Backend that emits a standalone HTML page driving Plotly.js.
#[derive(Debug)]
pub struct HtmlBackend {
    /// Name reported through `VisualizationBackend::name` (`"html"`).
    name: String,
}
514
515impl Default for HtmlBackend {
516 fn default() -> Self {
517 Self::new()
518 }
519}
520
521impl HtmlBackend {
522 pub fn new() -> Self {
524 Self {
525 name: "html".to_string(),
526 }
527 }
528
529 fn generate_html_template(&self, title: &str, content: &str) -> String {
531 format!(
532 r#"<!DOCTYPE html>
533<html>
534<head>
535 <title>{}</title>
536 <script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
537 <style>
538 body {{ font-family: Arial, sans-serif; margin: 20px; }}
539 .plot-container {{ margin: 20px 0; }}
540 .plot-title {{ font-size: 18px; font-weight: bold; margin-bottom: 10px; }}
541 </style>
542</head>
543<body>
544 <div class="plot-container">
545 <div class="plot-title">{}</div>
546 <div id="plot">{}</div>
547 </div>
548</body>
549</html>"#,
550 title, title, content
551 )
552 }
553}
554
555impl VisualizationBackend for HtmlBackend {
556 fn render_feature_importance(
557 &self,
558 data: &FeatureImportanceData,
559 config: &BackendConfig,
560 ) -> SklResult<RenderedVisualization> {
561 let start_time = std::time::Instant::now();
562
563 let plot_data = match data.plot_type {
565 FeatureImportanceType::Bar => {
566 let x_data: Vec<String> = data
567 .feature_names
568 .iter()
569 .map(|name| format!("\"{}\"", name))
570 .collect();
571 let y_data: Vec<String> = data
572 .importance_values
573 .iter()
574 .map(|val| val.to_string())
575 .collect();
576
577 format!(
578 r#"
579 var data = [{{
580 x: [{}],
581 y: [{}],
582 type: 'bar',
583 marker: {{
584 color: '#1f77b4'
585 }}
586 }}];
587
588 var layout = {{
589 title: '{}',
590 xaxis: {{ title: '{}' }},
591 yaxis: {{ title: '{}' }},
592 width: {},
593 height: {}
594 }};
595
596 Plotly.newPlot('plot', data, layout);
597 "#,
598 x_data.join(", "),
599 y_data.join(", "),
600 data.title,
601 data.x_label,
602 data.y_label,
603 config.width,
604 config.height
605 )
606 }
607 _ => {
608 let x_data: Vec<String> = data
610 .feature_names
611 .iter()
612 .map(|name| format!("\"{}\"", name))
613 .collect();
614 let y_data: Vec<String> = data
615 .importance_values
616 .iter()
617 .map(|val| val.to_string())
618 .collect();
619
620 format!(
621 r#"
622 var data = [{{
623 x: [{}],
624 y: [{}],
625 type: 'bar'
626 }}];
627
628 var layout = {{
629 title: '{}',
630 width: {},
631 height: {}
632 }};
633
634 Plotly.newPlot('plot', data, layout);
635 "#,
636 x_data.join(", "),
637 y_data.join(", "),
638 data.title,
639 config.width,
640 config.height
641 )
642 }
643 };
644
645 let html_content = self.generate_html_template(&data.title, &plot_data);
646 let render_time = start_time.elapsed().as_millis() as u64;
647
648 Ok(RenderedVisualization {
649 content: html_content.clone(),
650 format: OutputFormat::Html,
651 metadata: VisualizationMetadata {
652 backend: self.name.clone(),
653 render_time_ms: render_time,
654 file_size_bytes: html_content.len(),
655 data_points: data.importance_values.len(),
656 created_at: chrono::Utc::now(),
657 },
658 binary_data: None,
659 })
660 }
661
662 fn render_shap_plot(
663 &self,
664 data: &ShapData,
665 config: &BackendConfig,
666 ) -> SklResult<RenderedVisualization> {
667 let start_time = std::time::Instant::now();
668
669 let plot_data = format!(
671 r#"
672 var data = [{{
673 z: {},
674 type: 'heatmap',
675 colorscale: 'RdBu'
676 }}];
677
678 var layout = {{
679 title: '{}',
680 xaxis: {{ title: 'Features' }},
681 yaxis: {{ title: 'Instances' }},
682 width: {},
683 height: {}
684 }};
685
686 Plotly.newPlot('plot', data, layout);
687 "#,
688 serde_json::to_string(&data.shap_values.to_owned().into_raw_vec()).unwrap(),
689 data.title,
690 config.width,
691 config.height
692 );
693
694 let html_content = self.generate_html_template(&data.title, &plot_data);
695 let render_time = start_time.elapsed().as_millis() as u64;
696
697 Ok(RenderedVisualization {
698 content: html_content.clone(),
699 format: OutputFormat::Html,
700 metadata: VisualizationMetadata {
701 backend: self.name.clone(),
702 render_time_ms: render_time,
703 file_size_bytes: html_content.len(),
704 data_points: data.shap_values.len(),
705 created_at: chrono::Utc::now(),
706 },
707 binary_data: None,
708 })
709 }
710
711 fn render_partial_dependence(
712 &self,
713 data: &PartialDependenceData,
714 config: &BackendConfig,
715 ) -> SklResult<RenderedVisualization> {
716 let start_time = std::time::Instant::now();
717
718 let x_data: Vec<String> = data
719 .feature_values
720 .iter()
721 .map(|val| val.to_string())
722 .collect();
723 let y_data: Vec<String> = data.pd_values.iter().map(|val| val.to_string()).collect();
724
725 let plot_data = format!(
726 r#"
727 var data = [{{
728 x: [{}],
729 y: [{}],
730 type: 'scatter',
731 mode: 'lines',
732 name: 'Partial Dependence'
733 }}];
734
735 var layout = {{
736 title: '{}',
737 xaxis: {{ title: '{}' }},
738 yaxis: {{ title: 'Partial Dependence' }},
739 width: {},
740 height: {}
741 }};
742
743 Plotly.newPlot('plot', data, layout);
744 "#,
745 x_data.join(", "),
746 y_data.join(", "),
747 data.title,
748 data.feature_name,
749 config.width,
750 config.height
751 );
752
753 let html_content = self.generate_html_template(&data.title, &plot_data);
754 let render_time = start_time.elapsed().as_millis() as u64;
755
756 Ok(RenderedVisualization {
757 content: html_content.clone(),
758 format: OutputFormat::Html,
759 metadata: VisualizationMetadata {
760 backend: self.name.clone(),
761 render_time_ms: render_time,
762 file_size_bytes: html_content.len(),
763 data_points: data.feature_values.len(),
764 created_at: chrono::Utc::now(),
765 },
766 binary_data: None,
767 })
768 }
769
770 fn render_comparative_plot(
771 &self,
772 data: &ComparativeData,
773 config: &BackendConfig,
774 ) -> SklResult<RenderedVisualization> {
775 let start_time = std::time::Instant::now();
776
777 let plot_data = format!(
779 r#"
780 var data = [];
781 var layout = {{
782 title: '{}',
783 width: {},
784 height: {}
785 }};
786
787 Plotly.newPlot('plot', data, layout);
788 "#,
789 data.title, config.width, config.height
790 );
791
792 let html_content = self.generate_html_template(&data.title, &plot_data);
793 let render_time = start_time.elapsed().as_millis() as u64;
794
795 Ok(RenderedVisualization {
796 content: html_content.clone(),
797 format: OutputFormat::Html,
798 metadata: VisualizationMetadata {
799 backend: self.name.clone(),
800 render_time_ms: render_time,
801 file_size_bytes: html_content.len(),
802 data_points: data.model_data.len(),
803 created_at: chrono::Utc::now(),
804 },
805 binary_data: None,
806 })
807 }
808
809 fn render_custom_plot(
810 &self,
811 data: &CustomPlotData,
812 config: &BackendConfig,
813 ) -> SklResult<RenderedVisualization> {
814 let start_time = std::time::Instant::now();
815
816 let plot_data = format!(
817 r#"
818 var data = {};
819 var layout = {{
820 title: '{}',
821 width: {},
822 height: {}
823 }};
824
825 Plotly.newPlot('plot', data, layout);
826 "#,
827 data.data, data.title, config.width, config.height
828 );
829
830 let html_content = self.generate_html_template(&data.title, &plot_data);
831 let render_time = start_time.elapsed().as_millis() as u64;
832
833 Ok(RenderedVisualization {
834 content: html_content.clone(),
835 format: OutputFormat::Html,
836 metadata: VisualizationMetadata {
837 backend: self.name.clone(),
838 render_time_ms: render_time,
839 file_size_bytes: html_content.len(),
840 data_points: 0,
841 created_at: chrono::Utc::now(),
842 },
843 binary_data: None,
844 })
845 }
846
847 fn name(&self) -> &str {
848 &self.name
849 }
850
851 fn supported_formats(&self) -> Vec<OutputFormat> {
852 vec![OutputFormat::Html]
853 }
854
855 fn supports_interactivity(&self) -> bool {
856 true
857 }
858
859 fn capabilities(&self) -> BackendCapabilities {
860 BackendCapabilities {
861 formats: vec![OutputFormat::Html],
862 interactive: true,
863 animations: true,
864 three_d: false,
865 custom_themes: true,
866 real_time_updates: true,
867 max_data_points: Some(10000),
868 }
869 }
870}
871
/// Backend that serializes the plot data itself as pretty-printed JSON.
#[derive(Debug)]
pub struct JsonBackend {
    /// Name reported through `VisualizationBackend::name` (`"json"`).
    name: String,
}
877
878impl Default for JsonBackend {
879 fn default() -> Self {
880 Self::new()
881 }
882}
883
884impl JsonBackend {
885 pub fn new() -> Self {
887 Self {
888 name: "json".to_string(),
889 }
890 }
891}
892
893impl VisualizationBackend for JsonBackend {
894 fn render_feature_importance(
895 &self,
896 data: &FeatureImportanceData,
897 _config: &BackendConfig,
898 ) -> SklResult<RenderedVisualization> {
899 let start_time = std::time::Instant::now();
900
901 let json_content = serde_json::to_string_pretty(data).map_err(|e| {
902 crate::SklearsError::InvalidInput(format!("JSON serialization failed: {}", e))
903 })?;
904
905 let render_time = start_time.elapsed().as_millis() as u64;
906
907 Ok(RenderedVisualization {
908 content: json_content.clone(),
909 format: OutputFormat::Json,
910 metadata: VisualizationMetadata {
911 backend: self.name.clone(),
912 render_time_ms: render_time,
913 file_size_bytes: json_content.len(),
914 data_points: data.importance_values.len(),
915 created_at: chrono::Utc::now(),
916 },
917 binary_data: None,
918 })
919 }
920
921 fn render_shap_plot(
922 &self,
923 data: &ShapData,
924 _config: &BackendConfig,
925 ) -> SklResult<RenderedVisualization> {
926 let start_time = std::time::Instant::now();
927
928 let json_content = serde_json::to_string_pretty(data).map_err(|e| {
929 crate::SklearsError::InvalidInput(format!("JSON serialization failed: {}", e))
930 })?;
931
932 let render_time = start_time.elapsed().as_millis() as u64;
933
934 Ok(RenderedVisualization {
935 content: json_content.clone(),
936 format: OutputFormat::Json,
937 metadata: VisualizationMetadata {
938 backend: self.name.clone(),
939 render_time_ms: render_time,
940 file_size_bytes: json_content.len(),
941 data_points: data.shap_values.len(),
942 created_at: chrono::Utc::now(),
943 },
944 binary_data: None,
945 })
946 }
947
948 fn render_partial_dependence(
949 &self,
950 data: &PartialDependenceData,
951 _config: &BackendConfig,
952 ) -> SklResult<RenderedVisualization> {
953 let start_time = std::time::Instant::now();
954
955 let json_content = serde_json::to_string_pretty(data).map_err(|e| {
956 crate::SklearsError::InvalidInput(format!("JSON serialization failed: {}", e))
957 })?;
958
959 let render_time = start_time.elapsed().as_millis() as u64;
960
961 Ok(RenderedVisualization {
962 content: json_content.clone(),
963 format: OutputFormat::Json,
964 metadata: VisualizationMetadata {
965 backend: self.name.clone(),
966 render_time_ms: render_time,
967 file_size_bytes: json_content.len(),
968 data_points: data.feature_values.len(),
969 created_at: chrono::Utc::now(),
970 },
971 binary_data: None,
972 })
973 }
974
975 fn render_comparative_plot(
976 &self,
977 data: &ComparativeData,
978 _config: &BackendConfig,
979 ) -> SklResult<RenderedVisualization> {
980 let start_time = std::time::Instant::now();
981
982 let json_content = serde_json::to_string_pretty(data).map_err(|e| {
983 crate::SklearsError::InvalidInput(format!("JSON serialization failed: {}", e))
984 })?;
985
986 let render_time = start_time.elapsed().as_millis() as u64;
987
988 Ok(RenderedVisualization {
989 content: json_content.clone(),
990 format: OutputFormat::Json,
991 metadata: VisualizationMetadata {
992 backend: self.name.clone(),
993 render_time_ms: render_time,
994 file_size_bytes: json_content.len(),
995 data_points: data.model_data.len(),
996 created_at: chrono::Utc::now(),
997 },
998 binary_data: None,
999 })
1000 }
1001
1002 fn render_custom_plot(
1003 &self,
1004 data: &CustomPlotData,
1005 _config: &BackendConfig,
1006 ) -> SklResult<RenderedVisualization> {
1007 let start_time = std::time::Instant::now();
1008
1009 let json_content = serde_json::to_string_pretty(data).map_err(|e| {
1010 crate::SklearsError::InvalidInput(format!("JSON serialization failed: {}", e))
1011 })?;
1012
1013 let render_time = start_time.elapsed().as_millis() as u64;
1014
1015 Ok(RenderedVisualization {
1016 content: json_content.clone(),
1017 format: OutputFormat::Json,
1018 metadata: VisualizationMetadata {
1019 backend: self.name.clone(),
1020 render_time_ms: render_time,
1021 file_size_bytes: json_content.len(),
1022 data_points: 0,
1023 created_at: chrono::Utc::now(),
1024 },
1025 binary_data: None,
1026 })
1027 }
1028
1029 fn name(&self) -> &str {
1030 &self.name
1031 }
1032
1033 fn supported_formats(&self) -> Vec<OutputFormat> {
1034 vec![OutputFormat::Json]
1035 }
1036
1037 fn supports_interactivity(&self) -> bool {
1038 false
1039 }
1040
1041 fn capabilities(&self) -> BackendCapabilities {
1042 BackendCapabilities {
1043 formats: vec![OutputFormat::Json],
1044 interactive: false,
1045 animations: false,
1046 three_d: false,
1047 custom_themes: false,
1048 real_time_updates: false,
1049 max_data_points: None,
1050 }
1051 }
1052}
1053
/// Backend that renders plots as plain text.
#[derive(Debug)]
pub struct AsciiBackend {
    /// Name reported through `VisualizationBackend::name` (`"ascii"`).
    name: String,
}
1059
1060impl Default for AsciiBackend {
1061 fn default() -> Self {
1062 Self::new()
1063 }
1064}
1065
1066impl AsciiBackend {
1067 pub fn new() -> Self {
1069 Self {
1070 name: "ascii".to_string(),
1071 }
1072 }
1073
1074 fn generate_ascii_bar_chart(
1076 &self,
1077 labels: &[String],
1078 values: &[Float],
1079 width: usize,
1080 height: usize,
1081 ) -> String {
1082 let max_value = values.iter().fold(0.0_f64, |acc, &x| acc.max(x));
1083 let bar_width = (width - 20) / labels.len().max(1);
1084 let scale = (height - 5) as Float / max_value;
1085
1086 let mut result = String::new();
1087
1088 for (i, (label, &value)) in labels.iter().zip(values.iter()).enumerate() {
1090 let bar_length = (value * scale / max_value * 50.0) as usize;
1091 let bar = "█".repeat(bar_length);
1092 result.push_str(&format!("{:15} │{:<50} {:.3}\n", label, bar, value));
1093 }
1094
1095 result
1096 }
1097}
1098
1099impl VisualizationBackend for AsciiBackend {
1100 fn render_feature_importance(
1101 &self,
1102 data: &FeatureImportanceData,
1103 config: &BackendConfig,
1104 ) -> SklResult<RenderedVisualization> {
1105 let start_time = std::time::Instant::now();
1106
1107 let ascii_content = format!(
1108 "{}\n{}\n{}\n{}",
1109 "=".repeat(60),
1110 data.title,
1111 "=".repeat(60),
1112 self.generate_ascii_bar_chart(
1113 &data.feature_names,
1114 &data.importance_values,
1115 config.width,
1116 config.height,
1117 )
1118 );
1119
1120 let render_time = start_time.elapsed().as_millis() as u64;
1121
1122 Ok(RenderedVisualization {
1123 content: ascii_content.clone(),
1124 format: OutputFormat::Ascii,
1125 metadata: VisualizationMetadata {
1126 backend: self.name.clone(),
1127 render_time_ms: render_time,
1128 file_size_bytes: ascii_content.len(),
1129 data_points: data.importance_values.len(),
1130 created_at: chrono::Utc::now(),
1131 },
1132 binary_data: None,
1133 })
1134 }
1135
1136 fn render_shap_plot(
1137 &self,
1138 data: &ShapData,
1139 _config: &BackendConfig,
1140 ) -> SklResult<RenderedVisualization> {
1141 let start_time = std::time::Instant::now();
1142
1143 let ascii_content = format!(
1144 "{}\n{}\n{}\nSHAP Values: {} instances x {} features\n",
1145 "=".repeat(60),
1146 data.title,
1147 "=".repeat(60),
1148 data.shap_values.nrows(),
1149 data.shap_values.ncols()
1150 );
1151
1152 let render_time = start_time.elapsed().as_millis() as u64;
1153
1154 Ok(RenderedVisualization {
1155 content: ascii_content.clone(),
1156 format: OutputFormat::Ascii,
1157 metadata: VisualizationMetadata {
1158 backend: self.name.clone(),
1159 render_time_ms: render_time,
1160 file_size_bytes: ascii_content.len(),
1161 data_points: data.shap_values.len(),
1162 created_at: chrono::Utc::now(),
1163 },
1164 binary_data: None,
1165 })
1166 }
1167
1168 fn render_partial_dependence(
1169 &self,
1170 data: &PartialDependenceData,
1171 _config: &BackendConfig,
1172 ) -> SklResult<RenderedVisualization> {
1173 let start_time = std::time::Instant::now();
1174
1175 let ascii_content = format!(
1176 "{}\n{}\n{}\nPartial Dependence for feature: {}\n",
1177 "=".repeat(60),
1178 data.title,
1179 "=".repeat(60),
1180 data.feature_name
1181 );
1182
1183 let render_time = start_time.elapsed().as_millis() as u64;
1184
1185 Ok(RenderedVisualization {
1186 content: ascii_content.clone(),
1187 format: OutputFormat::Ascii,
1188 metadata: VisualizationMetadata {
1189 backend: self.name.clone(),
1190 render_time_ms: render_time,
1191 file_size_bytes: ascii_content.len(),
1192 data_points: data.feature_values.len(),
1193 created_at: chrono::Utc::now(),
1194 },
1195 binary_data: None,
1196 })
1197 }
1198
1199 fn render_comparative_plot(
1200 &self,
1201 data: &ComparativeData,
1202 _config: &BackendConfig,
1203 ) -> SklResult<RenderedVisualization> {
1204 let start_time = std::time::Instant::now();
1205
1206 let ascii_content = format!(
1207 "{}\n{}\n{}\nComparative plot with {} models\n",
1208 "=".repeat(60),
1209 data.title,
1210 "=".repeat(60),
1211 data.model_data.len()
1212 );
1213
1214 let render_time = start_time.elapsed().as_millis() as u64;
1215
1216 Ok(RenderedVisualization {
1217 content: ascii_content.clone(),
1218 format: OutputFormat::Ascii,
1219 metadata: VisualizationMetadata {
1220 backend: self.name.clone(),
1221 render_time_ms: render_time,
1222 file_size_bytes: ascii_content.len(),
1223 data_points: data.model_data.len(),
1224 created_at: chrono::Utc::now(),
1225 },
1226 binary_data: None,
1227 })
1228 }
1229
1230 fn render_custom_plot(
1231 &self,
1232 data: &CustomPlotData,
1233 _config: &BackendConfig,
1234 ) -> SklResult<RenderedVisualization> {
1235 let start_time = std::time::Instant::now();
1236
1237 let ascii_content = format!(
1238 "{}\n{}\n{}\nCustom plot type: {}\n",
1239 "=".repeat(60),
1240 data.title,
1241 "=".repeat(60),
1242 data.plot_type
1243 );
1244
1245 let render_time = start_time.elapsed().as_millis() as u64;
1246
1247 Ok(RenderedVisualization {
1248 content: ascii_content.clone(),
1249 format: OutputFormat::Ascii,
1250 metadata: VisualizationMetadata {
1251 backend: self.name.clone(),
1252 render_time_ms: render_time,
1253 file_size_bytes: ascii_content.len(),
1254 data_points: 0,
1255 created_at: chrono::Utc::now(),
1256 },
1257 binary_data: None,
1258 })
1259 }
1260
1261 fn name(&self) -> &str {
1262 &self.name
1263 }
1264
1265 fn supported_formats(&self) -> Vec<OutputFormat> {
1266 vec![OutputFormat::Ascii, OutputFormat::Unicode]
1267 }
1268
1269 fn supports_interactivity(&self) -> bool {
1270 false
1271 }
1272
1273 fn capabilities(&self) -> BackendCapabilities {
1274 BackendCapabilities {
1275 formats: vec![OutputFormat::Ascii, OutputFormat::Unicode],
1276 interactive: false,
1277 animations: false,
1278 three_d: false,
1279 custom_themes: false,
1280 real_time_updates: false,
1281 max_data_points: Some(1000),
1282 }
1283 }
1284}
1285
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Two-feature bar-chart payload shared by the rendering tests.
    fn sample_importance() -> FeatureImportanceData {
        FeatureImportanceData {
            feature_names: vec!["Feature1".to_string(), "Feature2".to_string()],
            importance_values: vec![0.6, 0.4],
            std_values: None,
            plot_type: FeatureImportanceType::Bar,
            title: "Test Plot".to_string(),
            x_label: "Features".to_string(),
            y_label: "Importance".to_string(),
        }
    }

    /// Registry holding the three built-in backends, HTML registered first.
    fn populated_registry() -> BackendRegistry {
        let mut registry = BackendRegistry::new();
        registry.register_backend(HtmlBackend::new());
        registry.register_backend(JsonBackend::new());
        registry.register_backend(AsciiBackend::new());
        registry
    }

    #[test]
    fn test_backend_registry() {
        let registry = populated_registry();

        // Lookup by name.
        for name in &["html", "json", "ascii"] {
            assert!(registry.get_backend(name).is_some());
        }
        assert!(registry.get_backend("nonexistent").is_none());

        // A default exists once anything has been registered.
        assert!(registry.get_default_backend().is_some());

        // Listing reports every registered backend.
        let backends = registry.list_backends();
        assert_eq!(backends.len(), 3);
        for name in &["html", "json", "ascii"] {
            assert!(backends.contains(&name.to_string()));
        }
    }

    #[test]
    fn test_visualization_renderer() {
        let renderer = VisualizationRenderer::with_default_backends();
        let config = BackendConfig::default();
        let plot_type = PlotType::FeatureImportance(sample_importance());

        // HTML backend.
        let rendered = renderer
            .render_with_backend("html", plot_type.clone(), &config)
            .expect("html rendering should succeed");
        assert_eq!(rendered.format, OutputFormat::Html);
        assert!(rendered.content.contains("Test Plot"));

        // JSON backend.
        let rendered = renderer
            .render_with_backend("json", plot_type.clone(), &config)
            .expect("json rendering should succeed");
        assert_eq!(rendered.format, OutputFormat::Json);

        // ASCII backend.
        let rendered = renderer
            .render_with_backend("ascii", plot_type, &config)
            .expect("ascii rendering should succeed");
        assert_eq!(rendered.format, OutputFormat::Ascii);
    }

    #[test]
    fn test_html_backend_feature_importance() {
        let rendered = HtmlBackend::new()
            .render_feature_importance(&sample_importance(), &BackendConfig::default())
            .expect("html rendering should succeed");

        assert_eq!(rendered.format, OutputFormat::Html);
        assert!(rendered.content.contains("Test Plot"));
        assert!(rendered.content.contains("Plotly"));
        assert_eq!(rendered.metadata.data_points, 2);
    }

    #[test]
    fn test_json_backend_feature_importance() {
        let rendered = JsonBackend::new()
            .render_feature_importance(&sample_importance(), &BackendConfig::default())
            .expect("json rendering should succeed");

        assert_eq!(rendered.format, OutputFormat::Json);
        assert!(rendered.content.contains("Feature1"));
        assert!(rendered.content.contains("Feature2"));
        assert_eq!(rendered.metadata.data_points, 2);
    }

    #[test]
    fn test_ascii_backend_feature_importance() {
        let rendered = AsciiBackend::new()
            .render_feature_importance(&sample_importance(), &BackendConfig::default())
            .expect("ascii rendering should succeed");

        assert_eq!(rendered.format, OutputFormat::Ascii);
        assert!(rendered.content.contains("Test Plot"));
        assert!(rendered.content.contains("Feature1"));
        assert!(rendered.content.contains("Feature2"));
        assert_eq!(rendered.metadata.data_points, 2);
    }

    #[test]
    fn test_shap_data_creation() {
        let shap_values =
            Array2::from_shape_vec((2, 3), vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6]).unwrap();
        let feature_values =
            Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();

        let data = ShapData {
            shap_values,
            feature_values,
            feature_names: vec!["F1".to_string(), "F2".to_string(), "F3".to_string()],
            instance_names: vec!["I1".to_string(), "I2".to_string()],
            plot_type: ShapPlotType::Summary,
            title: "SHAP Test".to_string(),
        };

        assert_eq!(data.shap_values.nrows(), 2);
        assert_eq!(data.shap_values.ncols(), 3);
        assert_eq!(data.feature_names.len(), 3);
        assert_eq!(data.instance_names.len(), 2);
    }

    #[test]
    fn test_backend_capabilities() {
        // Only the HTML backend advertises the interactive features.
        let html_caps = HtmlBackend::new().capabilities();
        assert!(html_caps.interactive);
        assert!(html_caps.animations);
        assert!(html_caps.real_time_updates);

        let json_caps = JsonBackend::new().capabilities();
        assert!(!json_caps.interactive);
        assert!(!json_caps.animations);
        assert!(!json_caps.real_time_updates);

        let ascii_caps = AsciiBackend::new().capabilities();
        assert!(!ascii_caps.interactive);
        assert!(!ascii_caps.animations);
        assert!(!ascii_caps.real_time_updates);
    }

    #[test]
    fn test_find_backends_for_format() {
        let registry = populated_registry();

        let expectations = [
            (OutputFormat::Html, "html"),
            (OutputFormat::Json, "json"),
            (OutputFormat::Ascii, "ascii"),
        ];

        for (format, expected) in expectations.iter() {
            let matches = registry.find_backends_for_format(*format);
            assert_eq!(matches.len(), 1);
            assert!(matches.contains(&expected.to_string()));
        }
    }
}