sklears_inspection/
visualization_backend.rs

1//! Extensible Visualization Backend System
2//!
3//! This module provides a trait-based system for pluggable visualization backends,
4//! allowing for flexible rendering to different output formats and libraries.
5
6use crate::{Float, SklResult};
7// ✅ SciRS2 Policy Compliant Import
8use scirs2_core::ndarray::{Array1, Array2};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::fmt::Debug;
12use std::sync::Arc;
13
14/// Trait for visualization backends
15pub trait VisualizationBackend: Debug + Send + Sync {
16    /// Render a feature importance plot
17    fn render_feature_importance(
18        &self,
19        data: &FeatureImportanceData,
20        config: &BackendConfig,
21    ) -> SklResult<RenderedVisualization>;
22
23    /// Render a SHAP plot
24    fn render_shap_plot(
25        &self,
26        data: &ShapData,
27        config: &BackendConfig,
28    ) -> SklResult<RenderedVisualization>;
29
30    /// Render a partial dependence plot
31    fn render_partial_dependence(
32        &self,
33        data: &PartialDependenceData,
34        config: &BackendConfig,
35    ) -> SklResult<RenderedVisualization>;
36
37    /// Render a comparative plot
38    fn render_comparative_plot(
39        &self,
40        data: &ComparativeData,
41        config: &BackendConfig,
42    ) -> SklResult<RenderedVisualization>;
43
44    /// Render a custom plot
45    fn render_custom_plot(
46        &self,
47        data: &CustomPlotData,
48        config: &BackendConfig,
49    ) -> SklResult<RenderedVisualization>;
50
51    /// Get backend name
52    fn name(&self) -> &str;
53
54    /// Get supported output formats
55    fn supported_formats(&self) -> Vec<OutputFormat>;
56
57    /// Check if backend supports interactivity
58    fn supports_interactivity(&self) -> bool;
59
60    /// Get backend capabilities
61    fn capabilities(&self) -> BackendCapabilities;
62}
63
64/// Backend configuration
65#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct BackendConfig {
67    /// Output format
68    pub format: OutputFormat,
69    /// Width in pixels
70    pub width: usize,
71    /// Height in pixels
72    pub height: usize,
73    /// DPI for high-resolution output
74    pub dpi: usize,
75    /// Whether to enable interactivity
76    pub interactive: bool,
77    /// Color scheme
78    pub color_scheme: ColorScheme,
79    /// Theme
80    pub theme: Theme,
81    /// Custom properties
82    pub custom_properties: HashMap<String, String>,
83}
84
85impl Default for BackendConfig {
86    fn default() -> Self {
87        Self {
88            format: OutputFormat::Html,
89            width: 800,
90            height: 600,
91            dpi: 96,
92            interactive: true,
93            color_scheme: ColorScheme::Default,
94            theme: Theme::Light,
95            custom_properties: HashMap::new(),
96        }
97    }
98}
99
100/// Output formats supported by backends
101#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
102pub enum OutputFormat {
103    /// Html
104    Html,
105    /// Json
106    Json,
107    /// Svg
108    Svg,
109    /// Png
110    Png,
111    /// Jpeg
112    Jpeg,
113    /// Pdf
114    Pdf,
115    /// Ascii
116    Ascii,
117    /// Unicode
118    Unicode,
119}
120
121/// Color schemes
122#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
123pub enum ColorScheme {
124    /// Default
125    Default,
126    /// Viridis
127    Viridis,
128    /// Plasma
129    Plasma,
130    /// Magma
131    Magma,
132    /// Inferno
133    Inferno,
134    /// Blues
135    Blues,
136    /// Reds
137    Reds,
138    /// Greens
139    Greens,
140    /// Categorical
141    Categorical,
142    /// Diverging
143    Diverging,
144}
145
146/// Visualization themes
147#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
148pub enum Theme {
149    /// Light
150    Light,
151    /// Dark
152    Dark,
153    /// HighContrast
154    HighContrast,
155    /// Minimal
156    Minimal,
157    /// Scientific
158    Scientific,
159}
160
161/// Backend capabilities
162#[derive(Debug, Clone, Serialize, Deserialize)]
163pub struct BackendCapabilities {
164    /// Supported output formats
165    pub formats: Vec<OutputFormat>,
166    /// Supports interactive features
167    pub interactive: bool,
168    /// Supports animations
169    pub animations: bool,
170    /// Supports 3D rendering
171    pub three_d: bool,
172    /// Supports custom themes
173    pub custom_themes: bool,
174    /// Supports real-time updates
175    pub real_time_updates: bool,
176    /// Maximum data points efficiently handled
177    pub max_data_points: Option<usize>,
178}
179
180/// Rendered visualization result
181#[derive(Debug, Clone, Serialize, Deserialize)]
182pub struct RenderedVisualization {
183    /// Rendered content
184    pub content: String,
185    /// Output format
186    pub format: OutputFormat,
187    /// Metadata about the visualization
188    pub metadata: VisualizationMetadata,
189    /// Optional binary data (for images)
190    pub binary_data: Option<Vec<u8>>,
191}
192
193/// Visualization metadata
194#[derive(Debug, Clone, Serialize, Deserialize)]
195pub struct VisualizationMetadata {
196    /// Backend used for rendering
197    pub backend: String,
198    /// Rendering time in milliseconds
199    pub render_time_ms: u64,
200    /// File size in bytes
201    pub file_size_bytes: usize,
202    /// Data points count
203    pub data_points: usize,
204    /// Creation timestamp
205    pub created_at: chrono::DateTime<chrono::Utc>,
206}
207
208/// Feature importance data for rendering
209#[derive(Debug, Clone, Serialize, Deserialize)]
210pub struct FeatureImportanceData {
211    /// Feature names
212    pub feature_names: Vec<String>,
213    /// Importance values
214    pub importance_values: Vec<Float>,
215    /// Standard deviations
216    pub std_values: Option<Vec<Float>>,
217    /// Plot type
218    pub plot_type: FeatureImportanceType,
219    /// Title
220    pub title: String,
221    /// Axis labels
222    pub x_label: String,
223    /// Y-axis label
224    pub y_label: String,
225}
226
227/// Types of feature importance plots
228#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
229pub enum FeatureImportanceType {
230    /// Bar
231    Bar,
232    /// Horizontal
233    Horizontal,
234    /// Radial
235    Radial,
236    /// TreeMap
237    TreeMap,
238    /// Waterfall
239    Waterfall,
240}
241
242/// SHAP data for rendering
243#[derive(Debug, Clone, Serialize, Deserialize)]
244pub struct ShapData {
245    /// SHAP values matrix (instances x features)
246    pub shap_values: Array2<Float>,
247    /// Feature values matrix
248    pub feature_values: Array2<Float>,
249    /// Feature names
250    pub feature_names: Vec<String>,
251    /// Instance names
252    pub instance_names: Vec<String>,
253    /// Plot type
254    pub plot_type: ShapPlotType,
255    /// Title
256    pub title: String,
257}
258
259/// Types of SHAP plots
260#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
261pub enum ShapPlotType {
262    /// Waterfall
263    Waterfall,
264    /// ForceLayout
265    ForceLayout,
266    /// Summary
267    Summary,
268    /// Dependence
269    Dependence,
270    /// Beeswarm
271    Beeswarm,
272    /// DecisionPlot
273    DecisionPlot,
274    /// Violin
275    Violin,
276}
277
278/// Partial dependence data for rendering
279#[derive(Debug, Clone, Serialize, Deserialize)]
280pub struct PartialDependenceData {
281    /// Feature values for x-axis
282    pub feature_values: Array1<Float>,
283    /// Partial dependence values
284    pub pd_values: Array1<Float>,
285    /// ICE curves (if available)
286    pub ice_curves: Option<Array2<Float>>,
287    /// Feature name
288    pub feature_name: String,
289    /// Title
290    pub title: String,
291    /// Show individual curves
292    pub show_ice: bool,
293}
294
295/// Comparative data for rendering
296#[derive(Debug, Clone, Serialize, Deserialize)]
297pub struct ComparativeData {
298    /// Data for different models/methods
299    pub model_data: HashMap<String, Array2<Float>>,
300    /// Labels for comparison
301    pub labels: Vec<String>,
302    /// Comparison type
303    pub comparison_type: ComparisonType,
304    /// Title
305    pub title: String,
306}
307
308/// Types of comparative plots
309#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
310pub enum ComparisonType {
311    /// SideBySide
312    SideBySide,
313    /// Overlay
314    Overlay,
315    /// Difference
316    Difference,
317    /// Ratio
318    Ratio,
319    /// Ranking
320    Ranking,
321}
322
323/// Custom plot data for extensibility
324#[derive(Debug, Clone, Serialize, Deserialize)]
325pub struct CustomPlotData {
326    /// Raw data as JSON value
327    pub data: serde_json::Value,
328    /// Plot type identifier
329    pub plot_type: String,
330    /// Title
331    pub title: String,
332    /// Additional metadata
333    pub metadata: HashMap<String, String>,
334}
335
336/// Backend registry for managing multiple backends
337#[derive(Debug, Default)]
338pub struct BackendRegistry {
339    backends: HashMap<String, Arc<dyn VisualizationBackend>>,
340    default_backend: Option<String>,
341}
342
343impl BackendRegistry {
344    /// Create a new backend registry
345    pub fn new() -> Self {
346        Self::default()
347    }
348
349    /// Register a new backend
350    pub fn register_backend<B: VisualizationBackend + 'static>(&mut self, backend: B) {
351        let name = backend.name().to_string();
352        self.backends.insert(name.clone(), Arc::new(backend));
353
354        // Set first backend as default if none set
355        if self.default_backend.is_none() {
356            self.default_backend = Some(name);
357        }
358    }
359
360    /// Get backend by name
361    pub fn get_backend(&self, name: &str) -> Option<Arc<dyn VisualizationBackend>> {
362        self.backends.get(name).cloned()
363    }
364
365    /// Get default backend
366    pub fn get_default_backend(&self) -> Option<Arc<dyn VisualizationBackend>> {
367        self.default_backend
368            .as_ref()
369            .and_then(|name| self.backends.get(name).cloned())
370    }
371
372    /// Set default backend
373    pub fn set_default_backend(&mut self, name: &str) -> SklResult<()> {
374        if self.backends.contains_key(name) {
375            self.default_backend = Some(name.to_string());
376            Ok(())
377        } else {
378            Err(crate::SklearsError::InvalidInput(format!(
379                "Backend '{}' not found",
380                name
381            )))
382        }
383    }
384
385    /// List all registered backends
386    pub fn list_backends(&self) -> Vec<String> {
387        self.backends.keys().cloned().collect()
388    }
389
390    /// Get backend capabilities
391    pub fn get_capabilities(&self, name: &str) -> Option<BackendCapabilities> {
392        self.backends.get(name).map(|b| b.capabilities())
393    }
394
395    /// Find backends supporting a specific format
396    pub fn find_backends_for_format(&self, format: OutputFormat) -> Vec<String> {
397        self.backends
398            .iter()
399            .filter(|(_, backend)| backend.supported_formats().contains(&format))
400            .map(|(name, _)| name.clone())
401            .collect()
402    }
403}
404
405/// Visualization renderer using pluggable backends
406#[derive(Debug)]
407pub struct VisualizationRenderer {
408    registry: BackendRegistry,
409}
410
411impl VisualizationRenderer {
412    /// Create a new visualization renderer
413    pub fn new() -> Self {
414        Self {
415            registry: BackendRegistry::new(),
416        }
417    }
418
419    /// Create with default backends
420    pub fn with_default_backends() -> Self {
421        let mut renderer = Self::new();
422        renderer.register_default_backends();
423        renderer
424    }
425
426    /// Register default backends
427    pub fn register_default_backends(&mut self) {
428        self.registry.register_backend(HtmlBackend::new());
429        self.registry.register_backend(JsonBackend::new());
430        self.registry.register_backend(AsciiBackend::new());
431    }
432
433    /// Register a custom backend
434    pub fn register_backend<B: VisualizationBackend + 'static>(&mut self, backend: B) {
435        self.registry.register_backend(backend);
436    }
437
438    /// Render with specific backend
439    pub fn render_with_backend(
440        &self,
441        backend_name: &str,
442        plot_type: PlotType,
443        config: &BackendConfig,
444    ) -> SklResult<RenderedVisualization> {
445        let backend = self.registry.get_backend(backend_name).ok_or_else(|| {
446            crate::SklearsError::InvalidInput(format!("Backend '{}' not found", backend_name))
447        })?;
448
449        match plot_type {
450            PlotType::FeatureImportance(data) => backend.render_feature_importance(&data, config),
451            PlotType::Shap(data) => backend.render_shap_plot(&data, config),
452            PlotType::PartialDependence(data) => backend.render_partial_dependence(&data, config),
453            PlotType::Comparative(data) => backend.render_comparative_plot(&data, config),
454            PlotType::Custom(data) => backend.render_custom_plot(&data, config),
455        }
456    }
457
458    /// Render with default backend
459    pub fn render(
460        &self,
461        plot_type: PlotType,
462        config: &BackendConfig,
463    ) -> SklResult<RenderedVisualization> {
464        let backend = self.registry.get_default_backend().ok_or_else(|| {
465            crate::SklearsError::InvalidInput("No default backend available".to_string())
466        })?;
467
468        match plot_type {
469            PlotType::FeatureImportance(data) => backend.render_feature_importance(&data, config),
470            PlotType::Shap(data) => backend.render_shap_plot(&data, config),
471            PlotType::PartialDependence(data) => backend.render_partial_dependence(&data, config),
472            PlotType::Comparative(data) => backend.render_comparative_plot(&data, config),
473            PlotType::Custom(data) => backend.render_custom_plot(&data, config),
474        }
475    }
476
477    /// Get backend registry
478    pub fn registry(&self) -> &BackendRegistry {
479        &self.registry
480    }
481
482    /// Get mutable backend registry
483    pub fn registry_mut(&mut self) -> &mut BackendRegistry {
484        &mut self.registry
485    }
486}
487
488impl Default for VisualizationRenderer {
489    fn default() -> Self {
490        Self::with_default_backends()
491    }
492}
493
494/// Plot type enum for rendering
495#[derive(Debug, Clone)]
496pub enum PlotType {
497    /// FeatureImportance
498    FeatureImportance(FeatureImportanceData),
499    /// Shap
500    Shap(ShapData),
501    /// PartialDependence
502    PartialDependence(PartialDependenceData),
503    /// Comparative
504    Comparative(ComparativeData),
505    /// Custom
506    Custom(CustomPlotData),
507}
508
/// Backend that renders plots as standalone HTML pages driven by Plotly.js.
#[derive(Debug)]
pub struct HtmlBackend {
    /// Registry name of this backend.
    name: String,
}
514
515impl Default for HtmlBackend {
516    fn default() -> Self {
517        Self::new()
518    }
519}
520
521impl HtmlBackend {
522    /// Create a new HTML backend
523    pub fn new() -> Self {
524        Self {
525            name: "html".to_string(),
526        }
527    }
528
529    /// Generate HTML template
530    fn generate_html_template(&self, title: &str, content: &str) -> String {
531        format!(
532            r#"<!DOCTYPE html>
533<html>
534<head>
535    <title>{}</title>
536    <script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
537    <style>
538        body {{ font-family: Arial, sans-serif; margin: 20px; }}
539        .plot-container {{ margin: 20px 0; }}
540        .plot-title {{ font-size: 18px; font-weight: bold; margin-bottom: 10px; }}
541    </style>
542</head>
543<body>
544    <div class="plot-container">
545        <div class="plot-title">{}</div>
546        <div id="plot">{}</div>
547    </div>
548</body>
549</html>"#,
550            title, title, content
551        )
552    }
553}
554
555impl VisualizationBackend for HtmlBackend {
556    fn render_feature_importance(
557        &self,
558        data: &FeatureImportanceData,
559        config: &BackendConfig,
560    ) -> SklResult<RenderedVisualization> {
561        let start_time = std::time::Instant::now();
562
563        // Generate Plotly.js bar chart
564        let plot_data = match data.plot_type {
565            FeatureImportanceType::Bar => {
566                let x_data: Vec<String> = data
567                    .feature_names
568                    .iter()
569                    .map(|name| format!("\"{}\"", name))
570                    .collect();
571                let y_data: Vec<String> = data
572                    .importance_values
573                    .iter()
574                    .map(|val| val.to_string())
575                    .collect();
576
577                format!(
578                    r#"
579                    var data = [{{
580                        x: [{}],
581                        y: [{}],
582                        type: 'bar',
583                        marker: {{
584                            color: '#1f77b4'
585                        }}
586                    }}];
587                    
588                    var layout = {{
589                        title: '{}',
590                        xaxis: {{ title: '{}' }},
591                        yaxis: {{ title: '{}' }},
592                        width: {},
593                        height: {}
594                    }};
595                    
596                    Plotly.newPlot('plot', data, layout);
597                    "#,
598                    x_data.join(", "),
599                    y_data.join(", "),
600                    data.title,
601                    data.x_label,
602                    data.y_label,
603                    config.width,
604                    config.height
605                )
606            }
607            _ => {
608                // Fallback to bar chart for other types
609                let x_data: Vec<String> = data
610                    .feature_names
611                    .iter()
612                    .map(|name| format!("\"{}\"", name))
613                    .collect();
614                let y_data: Vec<String> = data
615                    .importance_values
616                    .iter()
617                    .map(|val| val.to_string())
618                    .collect();
619
620                format!(
621                    r#"
622                    var data = [{{
623                        x: [{}],
624                        y: [{}],
625                        type: 'bar'
626                    }}];
627                    
628                    var layout = {{
629                        title: '{}',
630                        width: {},
631                        height: {}
632                    }};
633                    
634                    Plotly.newPlot('plot', data, layout);
635                    "#,
636                    x_data.join(", "),
637                    y_data.join(", "),
638                    data.title,
639                    config.width,
640                    config.height
641                )
642            }
643        };
644
645        let html_content = self.generate_html_template(&data.title, &plot_data);
646        let render_time = start_time.elapsed().as_millis() as u64;
647
648        Ok(RenderedVisualization {
649            content: html_content.clone(),
650            format: OutputFormat::Html,
651            metadata: VisualizationMetadata {
652                backend: self.name.clone(),
653                render_time_ms: render_time,
654                file_size_bytes: html_content.len(),
655                data_points: data.importance_values.len(),
656                created_at: chrono::Utc::now(),
657            },
658            binary_data: None,
659        })
660    }
661
662    fn render_shap_plot(
663        &self,
664        data: &ShapData,
665        config: &BackendConfig,
666    ) -> SklResult<RenderedVisualization> {
667        let start_time = std::time::Instant::now();
668
669        // Generate basic SHAP visualization
670        let plot_data = format!(
671            r#"
672            var data = [{{
673                z: {},
674                type: 'heatmap',
675                colorscale: 'RdBu'
676            }}];
677            
678            var layout = {{
679                title: '{}',
680                xaxis: {{ title: 'Features' }},
681                yaxis: {{ title: 'Instances' }},
682                width: {},
683                height: {}
684            }};
685            
686            Plotly.newPlot('plot', data, layout);
687            "#,
688            serde_json::to_string(&data.shap_values.to_owned().into_raw_vec()).unwrap(),
689            data.title,
690            config.width,
691            config.height
692        );
693
694        let html_content = self.generate_html_template(&data.title, &plot_data);
695        let render_time = start_time.elapsed().as_millis() as u64;
696
697        Ok(RenderedVisualization {
698            content: html_content.clone(),
699            format: OutputFormat::Html,
700            metadata: VisualizationMetadata {
701                backend: self.name.clone(),
702                render_time_ms: render_time,
703                file_size_bytes: html_content.len(),
704                data_points: data.shap_values.len(),
705                created_at: chrono::Utc::now(),
706            },
707            binary_data: None,
708        })
709    }
710
711    fn render_partial_dependence(
712        &self,
713        data: &PartialDependenceData,
714        config: &BackendConfig,
715    ) -> SklResult<RenderedVisualization> {
716        let start_time = std::time::Instant::now();
717
718        let x_data: Vec<String> = data
719            .feature_values
720            .iter()
721            .map(|val| val.to_string())
722            .collect();
723        let y_data: Vec<String> = data.pd_values.iter().map(|val| val.to_string()).collect();
724
725        let plot_data = format!(
726            r#"
727            var data = [{{
728                x: [{}],
729                y: [{}],
730                type: 'scatter',
731                mode: 'lines',
732                name: 'Partial Dependence'
733            }}];
734            
735            var layout = {{
736                title: '{}',
737                xaxis: {{ title: '{}' }},
738                yaxis: {{ title: 'Partial Dependence' }},
739                width: {},
740                height: {}
741            }};
742            
743            Plotly.newPlot('plot', data, layout);
744            "#,
745            x_data.join(", "),
746            y_data.join(", "),
747            data.title,
748            data.feature_name,
749            config.width,
750            config.height
751        );
752
753        let html_content = self.generate_html_template(&data.title, &plot_data);
754        let render_time = start_time.elapsed().as_millis() as u64;
755
756        Ok(RenderedVisualization {
757            content: html_content.clone(),
758            format: OutputFormat::Html,
759            metadata: VisualizationMetadata {
760                backend: self.name.clone(),
761                render_time_ms: render_time,
762                file_size_bytes: html_content.len(),
763                data_points: data.feature_values.len(),
764                created_at: chrono::Utc::now(),
765            },
766            binary_data: None,
767        })
768    }
769
770    fn render_comparative_plot(
771        &self,
772        data: &ComparativeData,
773        config: &BackendConfig,
774    ) -> SklResult<RenderedVisualization> {
775        let start_time = std::time::Instant::now();
776
777        // Simple comparative plot implementation
778        let plot_data = format!(
779            r#"
780            var data = [];
781            var layout = {{
782                title: '{}',
783                width: {},
784                height: {}
785            }};
786            
787            Plotly.newPlot('plot', data, layout);
788            "#,
789            data.title, config.width, config.height
790        );
791
792        let html_content = self.generate_html_template(&data.title, &plot_data);
793        let render_time = start_time.elapsed().as_millis() as u64;
794
795        Ok(RenderedVisualization {
796            content: html_content.clone(),
797            format: OutputFormat::Html,
798            metadata: VisualizationMetadata {
799                backend: self.name.clone(),
800                render_time_ms: render_time,
801                file_size_bytes: html_content.len(),
802                data_points: data.model_data.len(),
803                created_at: chrono::Utc::now(),
804            },
805            binary_data: None,
806        })
807    }
808
809    fn render_custom_plot(
810        &self,
811        data: &CustomPlotData,
812        config: &BackendConfig,
813    ) -> SklResult<RenderedVisualization> {
814        let start_time = std::time::Instant::now();
815
816        let plot_data = format!(
817            r#"
818            var data = {};
819            var layout = {{
820                title: '{}',
821                width: {},
822                height: {}
823            }};
824            
825            Plotly.newPlot('plot', data, layout);
826            "#,
827            data.data, data.title, config.width, config.height
828        );
829
830        let html_content = self.generate_html_template(&data.title, &plot_data);
831        let render_time = start_time.elapsed().as_millis() as u64;
832
833        Ok(RenderedVisualization {
834            content: html_content.clone(),
835            format: OutputFormat::Html,
836            metadata: VisualizationMetadata {
837                backend: self.name.clone(),
838                render_time_ms: render_time,
839                file_size_bytes: html_content.len(),
840                data_points: 0,
841                created_at: chrono::Utc::now(),
842            },
843            binary_data: None,
844        })
845    }
846
847    fn name(&self) -> &str {
848        &self.name
849    }
850
851    fn supported_formats(&self) -> Vec<OutputFormat> {
852        vec![OutputFormat::Html]
853    }
854
855    fn supports_interactivity(&self) -> bool {
856        true
857    }
858
859    fn capabilities(&self) -> BackendCapabilities {
860        BackendCapabilities {
861            formats: vec![OutputFormat::Html],
862            interactive: true,
863            animations: true,
864            three_d: false,
865            custom_themes: true,
866            real_time_updates: true,
867            max_data_points: Some(10000),
868        }
869    }
870}
871
/// Backend that emits the plot payload itself as pretty-printed JSON.
#[derive(Debug)]
pub struct JsonBackend {
    /// Registry name of this backend.
    name: String,
}
877
878impl Default for JsonBackend {
879    fn default() -> Self {
880        Self::new()
881    }
882}
883
884impl JsonBackend {
885    /// Create a new JSON backend
886    pub fn new() -> Self {
887        Self {
888            name: "json".to_string(),
889        }
890    }
891}
892
893impl VisualizationBackend for JsonBackend {
894    fn render_feature_importance(
895        &self,
896        data: &FeatureImportanceData,
897        _config: &BackendConfig,
898    ) -> SklResult<RenderedVisualization> {
899        let start_time = std::time::Instant::now();
900
901        let json_content = serde_json::to_string_pretty(data).map_err(|e| {
902            crate::SklearsError::InvalidInput(format!("JSON serialization failed: {}", e))
903        })?;
904
905        let render_time = start_time.elapsed().as_millis() as u64;
906
907        Ok(RenderedVisualization {
908            content: json_content.clone(),
909            format: OutputFormat::Json,
910            metadata: VisualizationMetadata {
911                backend: self.name.clone(),
912                render_time_ms: render_time,
913                file_size_bytes: json_content.len(),
914                data_points: data.importance_values.len(),
915                created_at: chrono::Utc::now(),
916            },
917            binary_data: None,
918        })
919    }
920
921    fn render_shap_plot(
922        &self,
923        data: &ShapData,
924        _config: &BackendConfig,
925    ) -> SklResult<RenderedVisualization> {
926        let start_time = std::time::Instant::now();
927
928        let json_content = serde_json::to_string_pretty(data).map_err(|e| {
929            crate::SklearsError::InvalidInput(format!("JSON serialization failed: {}", e))
930        })?;
931
932        let render_time = start_time.elapsed().as_millis() as u64;
933
934        Ok(RenderedVisualization {
935            content: json_content.clone(),
936            format: OutputFormat::Json,
937            metadata: VisualizationMetadata {
938                backend: self.name.clone(),
939                render_time_ms: render_time,
940                file_size_bytes: json_content.len(),
941                data_points: data.shap_values.len(),
942                created_at: chrono::Utc::now(),
943            },
944            binary_data: None,
945        })
946    }
947
948    fn render_partial_dependence(
949        &self,
950        data: &PartialDependenceData,
951        _config: &BackendConfig,
952    ) -> SklResult<RenderedVisualization> {
953        let start_time = std::time::Instant::now();
954
955        let json_content = serde_json::to_string_pretty(data).map_err(|e| {
956            crate::SklearsError::InvalidInput(format!("JSON serialization failed: {}", e))
957        })?;
958
959        let render_time = start_time.elapsed().as_millis() as u64;
960
961        Ok(RenderedVisualization {
962            content: json_content.clone(),
963            format: OutputFormat::Json,
964            metadata: VisualizationMetadata {
965                backend: self.name.clone(),
966                render_time_ms: render_time,
967                file_size_bytes: json_content.len(),
968                data_points: data.feature_values.len(),
969                created_at: chrono::Utc::now(),
970            },
971            binary_data: None,
972        })
973    }
974
975    fn render_comparative_plot(
976        &self,
977        data: &ComparativeData,
978        _config: &BackendConfig,
979    ) -> SklResult<RenderedVisualization> {
980        let start_time = std::time::Instant::now();
981
982        let json_content = serde_json::to_string_pretty(data).map_err(|e| {
983            crate::SklearsError::InvalidInput(format!("JSON serialization failed: {}", e))
984        })?;
985
986        let render_time = start_time.elapsed().as_millis() as u64;
987
988        Ok(RenderedVisualization {
989            content: json_content.clone(),
990            format: OutputFormat::Json,
991            metadata: VisualizationMetadata {
992                backend: self.name.clone(),
993                render_time_ms: render_time,
994                file_size_bytes: json_content.len(),
995                data_points: data.model_data.len(),
996                created_at: chrono::Utc::now(),
997            },
998            binary_data: None,
999        })
1000    }
1001
1002    fn render_custom_plot(
1003        &self,
1004        data: &CustomPlotData,
1005        _config: &BackendConfig,
1006    ) -> SklResult<RenderedVisualization> {
1007        let start_time = std::time::Instant::now();
1008
1009        let json_content = serde_json::to_string_pretty(data).map_err(|e| {
1010            crate::SklearsError::InvalidInput(format!("JSON serialization failed: {}", e))
1011        })?;
1012
1013        let render_time = start_time.elapsed().as_millis() as u64;
1014
1015        Ok(RenderedVisualization {
1016            content: json_content.clone(),
1017            format: OutputFormat::Json,
1018            metadata: VisualizationMetadata {
1019                backend: self.name.clone(),
1020                render_time_ms: render_time,
1021                file_size_bytes: json_content.len(),
1022                data_points: 0,
1023                created_at: chrono::Utc::now(),
1024            },
1025            binary_data: None,
1026        })
1027    }
1028
1029    fn name(&self) -> &str {
1030        &self.name
1031    }
1032
1033    fn supported_formats(&self) -> Vec<OutputFormat> {
1034        vec![OutputFormat::Json]
1035    }
1036
1037    fn supports_interactivity(&self) -> bool {
1038        false
1039    }
1040
1041    fn capabilities(&self) -> BackendCapabilities {
1042        BackendCapabilities {
1043            formats: vec![OutputFormat::Json],
1044            interactive: false,
1045            animations: false,
1046            three_d: false,
1047            custom_themes: false,
1048            real_time_updates: false,
1049            max_data_points: None,
1050        }
1051    }
1052}
1053
/// ASCII backend implementation for terminal output.
///
/// Renders every plot kind as plain text tagged `OutputFormat::Ascii`; the
/// bar-chart rows use the Unicode `█` and `│` characters.
#[derive(Debug)]
pub struct AsciiBackend {
    /// Name returned by `VisualizationBackend::name` (`"ascii"` when built via `new`).
    name: String,
}
1059
1060impl Default for AsciiBackend {
1061    fn default() -> Self {
1062        Self::new()
1063    }
1064}
1065
1066impl AsciiBackend {
1067    /// Create a new ASCII backend
1068    pub fn new() -> Self {
1069        Self {
1070            name: "ascii".to_string(),
1071        }
1072    }
1073
1074    /// Generate ASCII bar chart
1075    fn generate_ascii_bar_chart(
1076        &self,
1077        labels: &[String],
1078        values: &[Float],
1079        width: usize,
1080        height: usize,
1081    ) -> String {
1082        let max_value = values.iter().fold(0.0_f64, |acc, &x| acc.max(x));
1083        let bar_width = (width - 20) / labels.len().max(1);
1084        let scale = (height - 5) as Float / max_value;
1085
1086        let mut result = String::new();
1087
1088        // Create horizontal bar chart
1089        for (i, (label, &value)) in labels.iter().zip(values.iter()).enumerate() {
1090            let bar_length = (value * scale / max_value * 50.0) as usize;
1091            let bar = "█".repeat(bar_length);
1092            result.push_str(&format!("{:15} │{:<50} {:.3}\n", label, bar, value));
1093        }
1094
1095        result
1096    }
1097}
1098
1099impl VisualizationBackend for AsciiBackend {
1100    fn render_feature_importance(
1101        &self,
1102        data: &FeatureImportanceData,
1103        config: &BackendConfig,
1104    ) -> SklResult<RenderedVisualization> {
1105        let start_time = std::time::Instant::now();
1106
1107        let ascii_content = format!(
1108            "{}\n{}\n{}\n{}",
1109            "=".repeat(60),
1110            data.title,
1111            "=".repeat(60),
1112            self.generate_ascii_bar_chart(
1113                &data.feature_names,
1114                &data.importance_values,
1115                config.width,
1116                config.height,
1117            )
1118        );
1119
1120        let render_time = start_time.elapsed().as_millis() as u64;
1121
1122        Ok(RenderedVisualization {
1123            content: ascii_content.clone(),
1124            format: OutputFormat::Ascii,
1125            metadata: VisualizationMetadata {
1126                backend: self.name.clone(),
1127                render_time_ms: render_time,
1128                file_size_bytes: ascii_content.len(),
1129                data_points: data.importance_values.len(),
1130                created_at: chrono::Utc::now(),
1131            },
1132            binary_data: None,
1133        })
1134    }
1135
1136    fn render_shap_plot(
1137        &self,
1138        data: &ShapData,
1139        _config: &BackendConfig,
1140    ) -> SklResult<RenderedVisualization> {
1141        let start_time = std::time::Instant::now();
1142
1143        let ascii_content = format!(
1144            "{}\n{}\n{}\nSHAP Values: {} instances x {} features\n",
1145            "=".repeat(60),
1146            data.title,
1147            "=".repeat(60),
1148            data.shap_values.nrows(),
1149            data.shap_values.ncols()
1150        );
1151
1152        let render_time = start_time.elapsed().as_millis() as u64;
1153
1154        Ok(RenderedVisualization {
1155            content: ascii_content.clone(),
1156            format: OutputFormat::Ascii,
1157            metadata: VisualizationMetadata {
1158                backend: self.name.clone(),
1159                render_time_ms: render_time,
1160                file_size_bytes: ascii_content.len(),
1161                data_points: data.shap_values.len(),
1162                created_at: chrono::Utc::now(),
1163            },
1164            binary_data: None,
1165        })
1166    }
1167
1168    fn render_partial_dependence(
1169        &self,
1170        data: &PartialDependenceData,
1171        _config: &BackendConfig,
1172    ) -> SklResult<RenderedVisualization> {
1173        let start_time = std::time::Instant::now();
1174
1175        let ascii_content = format!(
1176            "{}\n{}\n{}\nPartial Dependence for feature: {}\n",
1177            "=".repeat(60),
1178            data.title,
1179            "=".repeat(60),
1180            data.feature_name
1181        );
1182
1183        let render_time = start_time.elapsed().as_millis() as u64;
1184
1185        Ok(RenderedVisualization {
1186            content: ascii_content.clone(),
1187            format: OutputFormat::Ascii,
1188            metadata: VisualizationMetadata {
1189                backend: self.name.clone(),
1190                render_time_ms: render_time,
1191                file_size_bytes: ascii_content.len(),
1192                data_points: data.feature_values.len(),
1193                created_at: chrono::Utc::now(),
1194            },
1195            binary_data: None,
1196        })
1197    }
1198
1199    fn render_comparative_plot(
1200        &self,
1201        data: &ComparativeData,
1202        _config: &BackendConfig,
1203    ) -> SklResult<RenderedVisualization> {
1204        let start_time = std::time::Instant::now();
1205
1206        let ascii_content = format!(
1207            "{}\n{}\n{}\nComparative plot with {} models\n",
1208            "=".repeat(60),
1209            data.title,
1210            "=".repeat(60),
1211            data.model_data.len()
1212        );
1213
1214        let render_time = start_time.elapsed().as_millis() as u64;
1215
1216        Ok(RenderedVisualization {
1217            content: ascii_content.clone(),
1218            format: OutputFormat::Ascii,
1219            metadata: VisualizationMetadata {
1220                backend: self.name.clone(),
1221                render_time_ms: render_time,
1222                file_size_bytes: ascii_content.len(),
1223                data_points: data.model_data.len(),
1224                created_at: chrono::Utc::now(),
1225            },
1226            binary_data: None,
1227        })
1228    }
1229
1230    fn render_custom_plot(
1231        &self,
1232        data: &CustomPlotData,
1233        _config: &BackendConfig,
1234    ) -> SklResult<RenderedVisualization> {
1235        let start_time = std::time::Instant::now();
1236
1237        let ascii_content = format!(
1238            "{}\n{}\n{}\nCustom plot type: {}\n",
1239            "=".repeat(60),
1240            data.title,
1241            "=".repeat(60),
1242            data.plot_type
1243        );
1244
1245        let render_time = start_time.elapsed().as_millis() as u64;
1246
1247        Ok(RenderedVisualization {
1248            content: ascii_content.clone(),
1249            format: OutputFormat::Ascii,
1250            metadata: VisualizationMetadata {
1251                backend: self.name.clone(),
1252                render_time_ms: render_time,
1253                file_size_bytes: ascii_content.len(),
1254                data_points: 0,
1255                created_at: chrono::Utc::now(),
1256            },
1257            binary_data: None,
1258        })
1259    }
1260
1261    fn name(&self) -> &str {
1262        &self.name
1263    }
1264
1265    fn supported_formats(&self) -> Vec<OutputFormat> {
1266        vec![OutputFormat::Ascii, OutputFormat::Unicode]
1267    }
1268
1269    fn supports_interactivity(&self) -> bool {
1270        false
1271    }
1272
1273    fn capabilities(&self) -> BackendCapabilities {
1274        BackendCapabilities {
1275            formats: vec![OutputFormat::Ascii, OutputFormat::Unicode],
1276            interactive: false,
1277            animations: false,
1278            three_d: false,
1279            custom_themes: false,
1280            real_time_updates: false,
1281            max_data_points: Some(1000),
1282        }
1283    }
1284}
1285
#[cfg(test)]
mod tests {
    use super::*;
    // ✅ SciRS2 Policy Compliant Import
    use scirs2_core::ndarray::Array2;

    /// Two-feature importance fixture titled "Test Plot", shared by the rendering tests.
    fn sample_feature_importance() -> FeatureImportanceData {
        FeatureImportanceData {
            feature_names: vec!["Feature1".to_string(), "Feature2".to_string()],
            importance_values: vec![0.6, 0.4],
            std_values: None,
            plot_type: FeatureImportanceType::Bar,
            title: "Test Plot".to_string(),
            x_label: "Features".to_string(),
            y_label: "Importance".to_string(),
        }
    }

    /// Registry with the HTML, JSON and ASCII backends registered.
    fn registry_with_builtin_backends() -> BackendRegistry {
        let mut registry = BackendRegistry::new();
        registry.register_backend(HtmlBackend::new());
        registry.register_backend(JsonBackend::new());
        registry.register_backend(AsciiBackend::new());
        registry
    }

    #[test]
    fn test_backend_registry() {
        let registry = registry_with_builtin_backends();

        // Test backend retrieval
        assert!(registry.get_backend("html").is_some());
        assert!(registry.get_backend("json").is_some());
        assert!(registry.get_backend("ascii").is_some());
        assert!(registry.get_backend("nonexistent").is_none());

        // Test default backend
        assert!(registry.get_default_backend().is_some());

        // Test backend listing
        let backends = registry.list_backends();
        assert_eq!(backends.len(), 3);
        assert!(backends.contains(&"html".to_string()));
        assert!(backends.contains(&"json".to_string()));
        assert!(backends.contains(&"ascii".to_string()));
    }

    #[test]
    fn test_visualization_renderer() {
        let mut renderer = VisualizationRenderer::with_default_backends();

        let config = BackendConfig::default();
        let plot_type = PlotType::FeatureImportance(sample_feature_importance());

        // Test HTML rendering
        let result = renderer.render_with_backend("html", plot_type.clone(), &config);
        assert!(result.is_ok());
        let rendered = result.unwrap();
        assert_eq!(rendered.format, OutputFormat::Html);
        assert!(rendered.content.contains("Test Plot"));

        // Test JSON rendering
        let result = renderer.render_with_backend("json", plot_type.clone(), &config);
        assert!(result.is_ok());
        let rendered = result.unwrap();
        assert_eq!(rendered.format, OutputFormat::Json);

        // Test ASCII rendering
        let result = renderer.render_with_backend("ascii", plot_type, &config);
        assert!(result.is_ok());
        let rendered = result.unwrap();
        assert_eq!(rendered.format, OutputFormat::Ascii);
    }

    #[test]
    fn test_html_backend_feature_importance() {
        let backend = HtmlBackend::new();
        let data = sample_feature_importance();

        let config = BackendConfig::default();
        let result = backend.render_feature_importance(&data, &config);

        assert!(result.is_ok());
        let rendered = result.unwrap();
        assert_eq!(rendered.format, OutputFormat::Html);
        assert!(rendered.content.contains("Test Plot"));
        assert!(rendered.content.contains("Plotly"));
        assert_eq!(rendered.metadata.data_points, 2);
    }

    #[test]
    fn test_json_backend_feature_importance() {
        let backend = JsonBackend::new();
        let data = sample_feature_importance();

        let config = BackendConfig::default();
        let result = backend.render_feature_importance(&data, &config);

        assert!(result.is_ok());
        let rendered = result.unwrap();
        assert_eq!(rendered.format, OutputFormat::Json);
        assert!(rendered.content.contains("Feature1"));
        assert!(rendered.content.contains("Feature2"));
        assert_eq!(rendered.metadata.data_points, 2);
    }

    #[test]
    fn test_ascii_backend_feature_importance() {
        let backend = AsciiBackend::new();
        let data = sample_feature_importance();

        let config = BackendConfig::default();
        let result = backend.render_feature_importance(&data, &config);

        assert!(result.is_ok());
        let rendered = result.unwrap();
        assert_eq!(rendered.format, OutputFormat::Ascii);
        assert!(rendered.content.contains("Test Plot"));
        assert!(rendered.content.contains("Feature1"));
        assert!(rendered.content.contains("Feature2"));
        assert_eq!(rendered.metadata.data_points, 2);
    }

    #[test]
    fn test_shap_data_creation() {
        let shap_values =
            Array2::from_shape_vec((2, 3), vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6]).unwrap();
        let feature_values =
            Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();

        let data = ShapData {
            shap_values,
            feature_values,
            feature_names: vec!["F1".to_string(), "F2".to_string(), "F3".to_string()],
            instance_names: vec!["I1".to_string(), "I2".to_string()],
            plot_type: ShapPlotType::Summary,
            title: "SHAP Test".to_string(),
        };

        assert_eq!(data.shap_values.nrows(), 2);
        assert_eq!(data.shap_values.ncols(), 3);
        assert_eq!(data.feature_names.len(), 3);
        assert_eq!(data.instance_names.len(), 2);
    }

    #[test]
    fn test_backend_capabilities() {
        let html_backend = HtmlBackend::new();
        let json_backend = JsonBackend::new();
        let ascii_backend = AsciiBackend::new();

        let html_caps = html_backend.capabilities();
        assert!(html_caps.interactive);
        assert!(html_caps.animations);
        assert!(html_caps.real_time_updates);

        let json_caps = json_backend.capabilities();
        assert!(!json_caps.interactive);
        assert!(!json_caps.animations);
        assert!(!json_caps.real_time_updates);

        let ascii_caps = ascii_backend.capabilities();
        assert!(!ascii_caps.interactive);
        assert!(!ascii_caps.animations);
        assert!(!ascii_caps.real_time_updates);
    }

    #[test]
    fn test_find_backends_for_format() {
        let registry = registry_with_builtin_backends();

        let html_backends = registry.find_backends_for_format(OutputFormat::Html);
        assert_eq!(html_backends.len(), 1);
        assert!(html_backends.contains(&"html".to_string()));

        let json_backends = registry.find_backends_for_format(OutputFormat::Json);
        assert_eq!(json_backends.len(), 1);
        assert!(json_backends.contains(&"json".to_string()));

        let ascii_backends = registry.find_backends_for_format(OutputFormat::Ascii);
        assert_eq!(ascii_backends.len(), 1);
        assert!(ascii_backends.contains(&"ascii".to_string()));
    }
}