sklears_inspection/
external_visualizations.rs

1//! External Visualization Library Integrations
2//!
3//! This module provides integrations with popular external visualization libraries
4//! such as Plotly, D3.js, Vega-Lite, and others for advanced interactive visualizations.
5
6use crate::{
7    visualization_backend::{
8        BackendCapabilities, BackendConfig, ComparativeData, CustomPlotData, FeatureImportanceData,
9        OutputFormat, PartialDependenceData, RenderedVisualization, ShapData, VisualizationBackend,
10        VisualizationMetadata,
11    },
12    Float, SklResult, SklearsError,
13};
14#[cfg(feature = "serde")]
15use serde::{Deserialize, Serialize};
16
/// Plotly backend for rich interactive visualizations.
///
/// Charts are emitted as Plotly.js code, either wrapped in a standalone HTML page
/// that loads Plotly from `PlotlyConfig::cdn_url`, or embedded in a JSON envelope.
#[derive(Debug)]
pub struct PlotlyBackend {
    /// Library location, default plot options and optional custom JavaScript.
    config: PlotlyConfig,
}
22
/// Configuration for Plotly backend
// The serde derives are gated on the `serde` feature to match the feature-gated
// `use serde::{Deserialize, Serialize}` import; unconditional derives fail to
// compile whenever that feature is disabled.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct PlotlyConfig {
    /// Plotly library version (informational; the script is loaded from `cdn_url`)
    pub version: String,
    /// CDN URL for Plotly.js, used as the `<script src>` of generated HTML pages
    pub cdn_url: String,
    /// Default plot configuration
    pub default_config: PlotlyPlotConfig,
    /// Whether to include responsive behavior (Plotly `responsive` option)
    pub responsive: bool,
    /// Custom JavaScript code appended after the generated plot code
    pub custom_js: Option<String>,
}

/// Plot-level options emitted into the Plotly.js `config` object of generated charts.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct PlotlyPlotConfig {
    /// Display the mode bar (Plotly `displayModeBar`)
    pub display_mode_bar: bool,
    /// Render a static, non-interactive plot (Plotly `staticPlot`)
    pub static_plot: bool,
    /// Show tips (Plotly `showTips`)
    pub show_tips: bool,
    /// Make the chart editable (Plotly `editable`)
    pub editable: bool,
}
49
50impl Default for PlotlyConfig {
51    fn default() -> Self {
52        Self {
53            version: "2.27.0".to_string(),
54            cdn_url: "https://cdn.plot.ly/plotly-2.27.0.min.js".to_string(),
55            default_config: PlotlyPlotConfig {
56                display_mode_bar: true,
57                static_plot: false,
58                show_tips: true,
59                editable: false,
60            },
61            responsive: true,
62            custom_js: None,
63        }
64    }
65}
66
67impl PlotlyBackend {
68    /// Create a new Plotly backend
69    pub fn new(config: PlotlyConfig) -> Self {
70        Self { config }
71    }
72
73    /// Generate Plotly JavaScript for feature importance
74    fn generate_feature_importance_plotly(
75        &self,
76        data: &FeatureImportanceData,
77        config: &BackendConfig,
78    ) -> SklResult<String> {
79        let mut x_values = Vec::new();
80        let mut y_values = Vec::new();
81        let mut error_y = Vec::new();
82
83        for (i, &importance) in data.importance_values.iter().enumerate() {
84            x_values.push(
85                data.feature_names
86                    .get(i)
87                    .cloned()
88                    .unwrap_or_else(|| format!("Feature {}", i)),
89            );
90            y_values.push(importance);
91            if let Some(ref std_values) = data.std_values {
92                error_y.push(std_values.get(i).copied().unwrap_or(0.0));
93            } else {
94                error_y.push(0.0);
95            }
96        }
97
98        let has_error = data.std_values.is_some() && !error_y.iter().all(|&x| x == 0.0);
99
100        let plot_data = if has_error {
101            format!(
102                r#"{{
103                    x: {:?},
104                    y: {:?},
105                    error_y: {{
106                        type: 'data',
107                        array: {:?},
108                        visible: true
109                    }},
110                    type: 'bar',
111                    name: 'Feature Importance',
112                    marker: {{
113                        color: 'rgba(158,202,225,0.8)',
114                        line: {{
115                            color: 'rgba(8,48,107,1.0)',
116                            width: 1.5
117                        }}
118                    }}
119                }}"#,
120                x_values, y_values, error_y
121            )
122        } else {
123            format!(
124                r#"{{
125                    x: {:?},
126                    y: {:?},
127                    type: 'bar',
128                    name: 'Feature Importance',
129                    marker: {{
130                        color: 'rgba(158,202,225,0.8)',
131                        line: {{
132                            color: 'rgba(8,48,107,1.0)',
133                            width: 1.5
134                        }}
135                    }}
136                }}"#,
137                x_values, y_values
138            )
139        };
140
141        let layout = format!(
142            r#"{{
143                title: 'Feature Importance',
144                xaxis: {{
145                    title: 'Features',
146                    tickangle: -45
147                }},
148                yaxis: {{
149                    title: 'Importance Score'
150                }},
151                width: {},
152                height: {},
153                margin: {{
154                    l: 60,
155                    r: 30,
156                    b: 120,
157                    t: 60
158                }}
159            }}"#,
160            config.width, config.height
161        );
162
163        let plot_config = format!(
164            r#"{{
165                displayModeBar: {},
166                staticPlot: {},
167                showTips: {},
168                editable: {},
169                responsive: {}
170            }}"#,
171            self.config.default_config.display_mode_bar,
172            self.config.default_config.static_plot,
173            self.config.default_config.show_tips,
174            self.config.default_config.editable,
175            self.config.responsive
176        );
177
178        Ok(format!(
179            r#"
180            var data = [{}];
181            var layout = {};
182            var config = {};
183            Plotly.newPlot('plotly-div', data, layout, config);
184            "#,
185            plot_data, layout, plot_config
186        ))
187    }
188
189    /// Generate Plotly JavaScript for SHAP values
190    fn generate_shap_plotly(&self, data: &ShapData, config: &BackendConfig) -> SklResult<String> {
191        let x_values: Vec<String> = data.feature_names.clone();
192        // Use first instance's SHAP values for visualization
193        let shap_values: Vec<Float> = if data.shap_values.nrows() > 0 {
194            data.shap_values.row(0).to_vec()
195        } else {
196            vec![0.0; data.feature_names.len()]
197        };
198
199        let plot_data = format!(
200            r#"{{
201                x: {:?},
202                y: {:?},
203                type: 'bar',
204                name: 'SHAP Values',
205                marker: {{
206                    color: {:?},
207                    colorscale: 'RdBu',
208                    line: {{
209                        color: 'rgba(0,0,0,0.2)',
210                        width: 1
211                    }}
212                }}
213            }}"#,
214            x_values,
215            shap_values,
216            shap_values
217                .iter()
218                .map(|&v| if v >= 0.0 {
219                    "rgba(255,0,0,0.8)"
220                } else {
221                    "rgba(0,0,255,0.8)"
222                })
223                .collect::<Vec<_>>()
224        );
225
226        let layout = format!(
227            r#"{{
228                title: 'SHAP Values',
229                xaxis: {{
230                    title: 'Features',
231                    tickangle: -45
232                }},
233                yaxis: {{
234                    title: 'SHAP Value',
235                    zeroline: true,
236                    zerolinecolor: 'rgb(0,0,0)',
237                    zerolinewidth: 2
238                }},
239                width: {},
240                height: {},
241                margin: {{
242                    l: 60,
243                    r: 30,
244                    b: 120,
245                    t: 60
246                }}
247            }}"#,
248            config.width, config.height
249        );
250
251        let plot_config = format!(
252            r#"{{
253                displayModeBar: {},
254                staticPlot: {},
255                responsive: {}
256            }}"#,
257            self.config.default_config.display_mode_bar,
258            self.config.default_config.static_plot,
259            self.config.responsive
260        );
261
262        Ok(format!(
263            r#"
264            var data = [{}];
265            var layout = {};
266            var config = {};
267            Plotly.newPlot('plotly-div', data, layout, config);
268            "#,
269            plot_data, layout, plot_config
270        ))
271    }
272
273    /// Generate complete HTML with Plotly
274    fn generate_complete_html(&self, js_code: &str, title: &str) -> String {
275        format!(
276            r#"<!DOCTYPE html>
277<html>
278<head>
279    <meta charset="utf-8">
280    <title>{}</title>
281    <script src="{}"></script>
282    <style>
283        body {{
284            font-family: Arial, sans-serif;
285            margin: 20px;
286        }}
287        #plotly-div {{
288            width: 100%;
289            height: 100%;
290        }}
291    </style>
292</head>
293<body>
294    <h1>{}</h1>
295    <div id="plotly-div"></div>
296    <script>
297        {}
298        {}
299    </script>
300</body>
301</html>"#,
302            title,
303            self.config.cdn_url,
304            title,
305            js_code,
306            self.config.custom_js.as_deref().unwrap_or("")
307        )
308    }
309}
310
311impl VisualizationBackend for PlotlyBackend {
312    fn render_feature_importance(
313        &self,
314        data: &FeatureImportanceData,
315        config: &BackendConfig,
316    ) -> SklResult<RenderedVisualization> {
317        let js_code = self.generate_feature_importance_plotly(data, config)?;
318
319        let content = match config.format {
320            OutputFormat::Html => self.generate_complete_html(&js_code, "Feature Importance"),
321            OutputFormat::Json => serde_json::to_string_pretty(&serde_json::json!({
322                "type": "plotly",
323                "javascript": js_code,
324                "title": "Feature Importance"
325            }))
326            .map_err(|e| SklearsError::Other(format!("JSON serialization error: {}", e)))?,
327            _ => {
328                return Err(SklearsError::InvalidInput(
329                    "Unsupported format for Plotly backend".to_string(),
330                ))
331            }
332        };
333
334        Ok(RenderedVisualization {
335            content,
336            format: config.format,
337            metadata: VisualizationMetadata {
338                backend: "plotly".to_string(),
339                render_time_ms: 0,
340                file_size_bytes: 0, // Will be calculated after content is moved
341                data_points: 1,     // Custom plot data count
342                created_at: chrono::Utc::now(),
343            },
344            binary_data: None,
345        })
346    }
347
348    fn render_shap_plot(
349        &self,
350        data: &ShapData,
351        config: &BackendConfig,
352    ) -> SklResult<RenderedVisualization> {
353        let js_code = self.generate_shap_plotly(data, config)?;
354
355        let content = match config.format {
356            OutputFormat::Html => self.generate_complete_html(&js_code, "SHAP Values"),
357            OutputFormat::Json => serde_json::to_string_pretty(&serde_json::json!({
358                "type": "plotly",
359                "javascript": js_code,
360                "title": "SHAP Values"
361            }))
362            .map_err(|e| SklearsError::Other(format!("JSON serialization error: {}", e)))?,
363            _ => {
364                return Err(SklearsError::InvalidInput(
365                    "Unsupported format for Plotly backend".to_string(),
366                ))
367            }
368        };
369
370        Ok(RenderedVisualization {
371            content,
372            format: config.format,
373            metadata: VisualizationMetadata {
374                backend: "plotly".to_string(),
375                render_time_ms: 0,
376                file_size_bytes: 0, // Will be calculated after content is moved
377                data_points: 1,     // Custom plot data count
378                created_at: chrono::Utc::now(),
379            },
380            binary_data: None,
381        })
382    }
383
384    fn render_partial_dependence(
385        &self,
386        data: &PartialDependenceData,
387        config: &BackendConfig,
388    ) -> SklResult<RenderedVisualization> {
389        let plot_data = format!(
390            r#"{{
391                x: {:?},
392                y: {:?},
393                type: 'scatter',
394                mode: 'lines+markers',
395                name: 'Partial Dependence',
396                line: {{
397                    color: 'rgb(31, 119, 180)',
398                    width: 3
399                }},
400                marker: {{
401                    color: 'rgb(31, 119, 180)',
402                    size: 6
403                }}
404            }}"#,
405            data.feature_values.to_vec(),
406            data.pd_values.to_vec()
407        );
408
409        let layout = format!(
410            r#"{{
411                title: 'Partial Dependence Plot',
412                xaxis: {{
413                    title: 'Feature Value'
414                }},
415                yaxis: {{
416                    title: 'Partial Dependence'
417                }},
418                width: {},
419                height: {}
420            }}"#,
421            config.width, config.height
422        );
423
424        let js_code = format!(
425            r#"
426            var data = [{}];
427            var layout = {};
428            Plotly.newPlot('plotly-div', data, layout);
429            "#,
430            plot_data, layout
431        );
432
433        let content = match config.format {
434            OutputFormat::Html => self.generate_complete_html(&js_code, "Partial Dependence Plot"),
435            OutputFormat::Json => serde_json::to_string_pretty(&serde_json::json!({
436                "type": "plotly",
437                "javascript": js_code,
438                "title": "Partial Dependence Plot"
439            }))
440            .map_err(|e| SklearsError::Other(format!("JSON serialization error: {}", e)))?,
441            _ => {
442                return Err(SklearsError::InvalidInput(
443                    "Unsupported format for Plotly backend".to_string(),
444                ))
445            }
446        };
447
448        Ok(RenderedVisualization {
449            content,
450            format: config.format,
451            metadata: VisualizationMetadata {
452                backend: "plotly".to_string(),
453                render_time_ms: 0,
454                file_size_bytes: 0, // Will be calculated after content is moved
455                data_points: data.feature_values.len(),
456                created_at: chrono::Utc::now(),
457            },
458            binary_data: None,
459        })
460    }
461
462    fn render_comparative_plot(
463        &self,
464        data: &ComparativeData,
465        config: &BackendConfig,
466    ) -> SklResult<RenderedVisualization> {
467        let mut traces = Vec::new();
468
469        for (i, (method_name, method_data)) in data.model_data.iter().enumerate() {
470            let color = match i % 6 {
471                0 => "rgb(31, 119, 180)",
472                1 => "rgb(255, 127, 14)",
473                2 => "rgb(44, 160, 44)",
474                3 => "rgb(214, 39, 40)",
475                4 => "rgb(148, 103, 189)",
476                _ => "rgb(140, 86, 75)",
477            };
478
479            // Use first row of data for visualization
480            let values: Vec<Float> = if method_data.nrows() > 0 {
481                method_data.row(0).to_vec()
482            } else {
483                vec![0.0; data.labels.len()]
484            };
485
486            traces.push(format!(
487                r#"{{
488                    x: {:?},
489                    y: {:?},
490                    type: 'bar',
491                    name: '{}',
492                    marker: {{ color: '{}' }}
493                }}"#,
494                data.labels, values, method_name, color
495            ));
496        }
497
498        let layout = format!(
499            r#"{{
500                title: 'Method Comparison',
501                xaxis: {{
502                    title: 'Features',
503                    tickangle: -45
504                }},
505                yaxis: {{
506                    title: 'Importance Score'
507                }},
508                barmode: 'group',
509                width: {},
510                height: {}
511            }}"#,
512            config.width, config.height
513        );
514
515        let js_code = format!(
516            r#"
517            var data = [{}];
518            var layout = {};
519            Plotly.newPlot('plotly-div', data, layout);
520            "#,
521            traces.join(","),
522            layout
523        );
524
525        let content = match config.format {
526            OutputFormat::Html => self.generate_complete_html(&js_code, "Method Comparison"),
527            OutputFormat::Json => serde_json::to_string_pretty(&serde_json::json!({
528                "type": "plotly",
529                "javascript": js_code,
530                "title": "Method Comparison"
531            }))
532            .map_err(|e| SklearsError::Other(format!("JSON serialization error: {}", e)))?,
533            _ => {
534                return Err(SklearsError::InvalidInput(
535                    "Unsupported format for Plotly backend".to_string(),
536                ))
537            }
538        };
539
540        Ok(RenderedVisualization {
541            content,
542            format: config.format,
543            metadata: VisualizationMetadata {
544                backend: "plotly".to_string(),
545                render_time_ms: 0,
546                file_size_bytes: 0, // Will be calculated after content is moved
547                data_points: data.labels.len(),
548                created_at: chrono::Utc::now(),
549            },
550            binary_data: None,
551        })
552    }
553
554    fn render_custom_plot(
555        &self,
556        data: &CustomPlotData,
557        config: &BackendConfig,
558    ) -> SklResult<RenderedVisualization> {
559        // For custom plots, we expect the data to contain Plotly-compatible JSON
560        let js_code = format!(
561            r#"
562            var data = {};
563            var layout = {};
564            Plotly.newPlot('plotly-div', data, layout);
565            "#,
566            data.data
567                .get("data")
568                .unwrap_or(&serde_json::Value::Array(vec![])),
569            data.data.get("layout").unwrap_or(&serde_json::json!({}))
570        );
571
572        let content = match config.format {
573            OutputFormat::Html => self.generate_complete_html(&js_code, "Custom Plot"),
574            OutputFormat::Json => serde_json::to_string_pretty(&serde_json::json!({
575                "type": "plotly",
576                "javascript": js_code,
577                "title": "Custom Plot"
578            }))
579            .map_err(|e| SklearsError::Other(format!("JSON serialization error: {}", e)))?,
580            _ => {
581                return Err(SklearsError::InvalidInput(
582                    "Unsupported format for Plotly backend".to_string(),
583                ))
584            }
585        };
586
587        Ok(RenderedVisualization {
588            content,
589            format: config.format,
590            metadata: VisualizationMetadata {
591                backend: "plotly".to_string(),
592                render_time_ms: 0,
593                file_size_bytes: 0, // Will be calculated after content is moved
594                data_points: 1,     // Custom plot
595                created_at: chrono::Utc::now(),
596            },
597            binary_data: None,
598        })
599    }
600
601    fn name(&self) -> &str {
602        "plotly"
603    }
604
605    fn supported_formats(&self) -> Vec<OutputFormat> {
606        vec![OutputFormat::Html, OutputFormat::Json]
607    }
608
609    fn supports_interactivity(&self) -> bool {
610        !self.config.default_config.static_plot
611    }
612
613    fn capabilities(&self) -> BackendCapabilities {
614        BackendCapabilities {
615            formats: self.supported_formats(),
616            interactive: self.supports_interactivity(),
617            animations: true,
618            three_d: false,
619            custom_themes: true,
620            real_time_updates: true,
621            max_data_points: Some(100000),
622        }
623    }
624}
625
/// D3.js backend for custom interactive visualizations.
///
/// Renders feature-importance bar charts and caller-supplied D3 code. SHAP,
/// partial-dependence and comparative plots currently return
/// `SklearsError::NotImplemented`.
#[derive(Debug)]
pub struct D3Backend {
    /// Library location plus optional custom CSS/JavaScript.
    config: D3Config,
}
631
/// Configuration for the D3.js backend
// The serde derives are gated on the `serde` feature to match the feature-gated
// `use serde::{Deserialize, Serialize}` import; unconditional derives fail to
// compile whenever that feature is disabled.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct D3Config {
    /// D3.js library version (informational; the script is loaded from `cdn_url`)
    pub version: String,
    /// CDN URL for D3.js, used as the `<script src>` of generated HTML pages
    pub cdn_url: String,
    /// Custom CSS appended to the generated page's `<style>` element
    pub custom_css: Option<String>,
    /// Custom JavaScript appended after the generated chart code
    pub custom_js: Option<String>,
}
643
644impl Default for D3Config {
645    fn default() -> Self {
646        Self {
647            version: "7.8.5".to_string(),
648            cdn_url: "https://cdn.jsdelivr.net/npm/d3@7".to_string(),
649            custom_css: None,
650            custom_js: None,
651        }
652    }
653}
654
655impl D3Backend {
656    /// Create a new D3.js backend
657    pub fn new(config: D3Config) -> Self {
658        Self { config }
659    }
660
661    /// Generate D3.js code for feature importance
662    fn generate_feature_importance_d3(
663        &self,
664        data: &FeatureImportanceData,
665        config: &BackendConfig,
666    ) -> SklResult<String> {
667        let data_json = serde_json::to_string(&serde_json::json!({
668            "features": data.importance_values.iter().enumerate().map(|(i, &val)| {
669                serde_json::json!({
670                    "name": data.feature_names.get(i).cloned().unwrap_or_else(|| format!("Feature {}", i)),
671                    "value": val,
672                    "std": data.std_values.as_ref().map(|s| s.get(i).copied().unwrap_or(0.0))
673                })
674            }).collect::<Vec<_>>()
675        })).map_err(|e| SklearsError::Other(format!("JSON serialization error: {}", e)))?;
676
677        Ok(format!(
678            r##"
679            const data = {};
680            
681            const margin = {{top: 20, right: 30, bottom: 70, left: 60}};
682            const width = {} - margin.left - margin.right;
683            const height = {} - margin.bottom - margin.top;
684
685            const svg = d3.select("#d3-div")
686                .append("svg")
687                .attr("width", width + margin.left + margin.right)
688                .attr("height", height + margin.top + margin.bottom);
689
690            const g = svg.append("g")
691                .attr("transform", `translate(${{margin.left}},${{margin.top}})`);
692
693            const x = d3.scaleBand()
694                .range([0, width])
695                .domain(data.features.map(d => d.name))
696                .padding(0.1);
697
698            const y = d3.scaleLinear()
699                .range([height, 0])
700                .domain([0, d3.max(data.features, d => d.value)]);
701
702            g.append("g")
703                .attr("transform", `translate(0,${{height}})`)
704                .call(d3.axisBottom(x))
705                .selectAll("text")
706                .style("text-anchor", "end")
707                .attr("dx", "-.8em")
708                .attr("dy", ".15em")
709                .attr("transform", "rotate(-45)");
710
711            g.append("g")
712                .call(d3.axisLeft(y));
713
714            g.selectAll(".bar")
715                .data(data.features)
716                .enter().append("rect")
717                .attr("class", "bar")
718                .attr("x", d => x(d.name))
719                .attr("width", x.bandwidth())
720                .attr("y", d => y(d.value))
721                .attr("height", d => height - y(d.value))
722                .attr("fill", "steelblue")
723                .on("mouseover", function(event, d) {{
724                    d3.select(this).attr("fill", "orange");
725                    
726                    const tooltip = d3.select("body").append("div")
727                        .attr("class", "tooltip")
728                        .style("opacity", 0)
729                        .style("position", "absolute")
730                        .style("background", "rgba(0,0,0,0.8)")
731                        .style("color", "white")
732                        .style("padding", "10px")
733                        .style("border-radius", "5px")
734                        .style("pointer-events", "none");
735
736                    tooltip.transition()
737                        .duration(200)
738                        .style("opacity", .9);
739                    
740                    tooltip.html(`${{d.name}}: ${{d.value.toFixed(4)}}`)
741                        .style("left", (event.pageX + 10) + "px")
742                        .style("top", (event.pageY - 28) + "px");
743                }})
744                .on("mouseout", function(d) {{
745                    d3.select(this).attr("fill", "steelblue");
746                    d3.selectAll(".tooltip").remove();
747                }});
748
749            // Add title
750            g.append("text")
751                .attr("x", width / 2)
752                .attr("y", 0 - (margin.top / 2))
753                .attr("text-anchor", "middle")
754                .style("font-size", "16px")
755                .style("font-weight", "bold")
756                .text("Feature Importance");
757
758            // Add axis labels
759            g.append("text")
760                .attr("transform", "rotate(-90)")
761                .attr("y", 0 - margin.left)
762                .attr("x", 0 - (height / 2))
763                .attr("dy", "1em")
764                .style("text-anchor", "middle")
765                .text("Importance Score");
766            "##,
767            data_json, config.width, config.height
768        ))
769    }
770
771    /// Generate complete HTML with D3.js
772    fn generate_complete_html(&self, js_code: &str, title: &str) -> String {
773        format!(
774            r#"<!DOCTYPE html>
775<html>
776<head>
777    <meta charset="utf-8">
778    <title>{}</title>
779    <script src="{}"></script>
780    <style>
781        body {{
782            font-family: Arial, sans-serif;
783            margin: 20px;
784        }}
785        #d3-div {{
786            width: 100%;
787            height: 100%;
788        }}
789        .bar:hover {{
790            fill: orange;
791        }}
792        {}
793    </style>
794</head>
795<body>
796    <h1>{}</h1>
797    <div id="d3-div"></div>
798    <script>
799        {}
800        {}
801    </script>
802</body>
803</html>"#,
804            title,
805            self.config.cdn_url,
806            self.config.custom_css.as_deref().unwrap_or(""),
807            title,
808            js_code,
809            self.config.custom_js.as_deref().unwrap_or("")
810        )
811    }
812}
813
814impl VisualizationBackend for D3Backend {
815    fn render_feature_importance(
816        &self,
817        data: &FeatureImportanceData,
818        config: &BackendConfig,
819    ) -> SklResult<RenderedVisualization> {
820        let js_code = self.generate_feature_importance_d3(data, config)?;
821
822        let content = match config.format {
823            OutputFormat::Html => self.generate_complete_html(&js_code, "Feature Importance"),
824            OutputFormat::Json => serde_json::to_string_pretty(&serde_json::json!({
825                "type": "d3",
826                "javascript": js_code,
827                "title": "Feature Importance"
828            }))
829            .map_err(|e| SklearsError::Other(format!("JSON serialization error: {}", e)))?,
830            _ => {
831                return Err(SklearsError::InvalidInput(
832                    "Unsupported format for D3 backend".to_string(),
833                ))
834            }
835        };
836
837        Ok(RenderedVisualization {
838            content,
839            format: config.format,
840            metadata: VisualizationMetadata {
841                backend: "d3".to_string(),
842                render_time_ms: 0,
843                file_size_bytes: 0, // Will be calculated after content is moved
844                data_points: 1,     // Custom plot data count
845                created_at: chrono::Utc::now(),
846            },
847            binary_data: None,
848        })
849    }
850
851    fn render_shap_plot(
852        &self,
853        _data: &ShapData,
854        _config: &BackendConfig,
855    ) -> SklResult<RenderedVisualization> {
856        // Implementation would be similar to feature importance but with SHAP-specific styling
857        Err(SklearsError::NotImplemented(
858            "SHAP plot for D3 backend not yet implemented".to_string(),
859        ))
860    }
861
862    fn render_partial_dependence(
863        &self,
864        _data: &PartialDependenceData,
865        _config: &BackendConfig,
866    ) -> SklResult<RenderedVisualization> {
867        // Implementation would create line plots for partial dependence
868        Err(SklearsError::NotImplemented(
869            "Partial dependence plot for D3 backend not yet implemented".to_string(),
870        ))
871    }
872
873    fn render_comparative_plot(
874        &self,
875        _data: &ComparativeData,
876        _config: &BackendConfig,
877    ) -> SklResult<RenderedVisualization> {
878        // Implementation would create grouped bar charts
879        Err(SklearsError::NotImplemented(
880            "Comparative plot for D3 backend not yet implemented".to_string(),
881        ))
882    }
883
884    fn render_custom_plot(
885        &self,
886        data: &CustomPlotData,
887        config: &BackendConfig,
888    ) -> SklResult<RenderedVisualization> {
889        // For custom plots, we expect D3.js code in the data
890        let js_code = data
891            .data
892            .get("d3_code")
893            .and_then(|v| v.as_str())
894            .ok_or_else(|| {
895                SklearsError::InvalidInput("D3 code not found in custom plot data".to_string())
896            })?;
897
898        let content = match config.format {
899            OutputFormat::Html => self.generate_complete_html(js_code, "Custom D3 Plot"),
900            OutputFormat::Json => serde_json::to_string_pretty(&serde_json::json!({
901                "type": "d3",
902                "javascript": js_code,
903                "title": "Custom D3 Plot"
904            }))
905            .map_err(|e| SklearsError::Other(format!("JSON serialization error: {}", e)))?,
906            _ => {
907                return Err(SklearsError::InvalidInput(
908                    "Unsupported format for D3 backend".to_string(),
909                ))
910            }
911        };
912
913        Ok(RenderedVisualization {
914            content,
915            format: config.format,
916            metadata: VisualizationMetadata {
917                backend: "d3".to_string(),
918                render_time_ms: 0,
919                file_size_bytes: 0, // Will be calculated after content is moved
920                data_points: 1,     // Custom plot data count
921                created_at: chrono::Utc::now(),
922            },
923            binary_data: None,
924        })
925    }
926
927    fn name(&self) -> &str {
928        "d3"
929    }
930
931    fn supported_formats(&self) -> Vec<OutputFormat> {
932        vec![OutputFormat::Html, OutputFormat::Json]
933    }
934
935    fn supports_interactivity(&self) -> bool {
936        true
937    }
938
939    fn capabilities(&self) -> BackendCapabilities {
940        BackendCapabilities {
941            formats: self.supported_formats(),
942            interactive: true,
943            animations: true,
944            three_d: false,
945            custom_themes: true,
946            real_time_updates: true,
947            max_data_points: Some(50000),
948        }
949    }
950}
951
/// Vega-Lite backend for grammar of graphics visualizations
///
/// Plots are expressed as Vega-Lite JSON specifications and displayed with
/// `vegaEmbed`. They are emitted either as a standalone HTML page or wrapped
/// in a JSON envelope.
#[derive(Debug)]
pub struct VegaLiteBackend {
    /// Script locations and embed options used when rendering.
    config: VegaLiteConfig,
}
957
958#[derive(Debug, Clone, Serialize, Deserialize)]
959pub struct VegaLiteConfig {
960    /// Vega-Lite version
961    pub version: String,
962    /// CDN URLs for Vega-Lite
963    pub vega_url: String,
964    /// vega_lite_url
965    pub vega_lite_url: String,
966    /// vega_embed_url
967    pub vega_embed_url: String,
968    /// Default theme
969    pub theme: String,
970    /// Default configuration
971    pub default_config: VegaLiteDefaultConfig,
972}
973
974#[derive(Debug, Clone, Serialize, Deserialize)]
975pub struct VegaLiteDefaultConfig {
976    /// Actions to show in embed
977    pub actions: Vec<String>,
978    /// Whether to show tooltip
979    pub tooltip: bool,
980    /// Renderer type
981    pub renderer: String,
982}
983
984impl Default for VegaLiteConfig {
985    fn default() -> Self {
986        Self {
987            version: "5.8.0".to_string(),
988            vega_url: "https://cdn.jsdelivr.net/npm/vega@5".to_string(),
989            vega_lite_url: "https://cdn.jsdelivr.net/npm/vega-lite@5".to_string(),
990            vega_embed_url: "https://cdn.jsdelivr.net/npm/vega-embed@6".to_string(),
991            theme: "default".to_string(),
992            default_config: VegaLiteDefaultConfig {
993                actions: vec![
994                    "export".to_string(),
995                    "source".to_string(),
996                    "compiled".to_string(),
997                    "editor".to_string(),
998                ],
999                tooltip: true,
1000                renderer: "canvas".to_string(),
1001            },
1002        }
1003    }
1004}
1005
1006impl VegaLiteBackend {
1007    /// Create a new Vega-Lite backend
1008    pub fn new(config: VegaLiteConfig) -> Self {
1009        Self { config }
1010    }
1011
1012    /// Generate Vega-Lite specification for feature importance
1013    fn generate_feature_importance_vega(
1014        &self,
1015        data: &FeatureImportanceData,
1016        config: &BackendConfig,
1017    ) -> SklResult<String> {
1018        let vega_data: Vec<serde_json::Value> = data.importance_values.iter().enumerate().map(|(i, &val)| {
1019            serde_json::json!({
1020                "feature": data.feature_names.get(i).cloned().unwrap_or_else(|| format!("Feature {}", i)),
1021                "importance": val,
1022                "std": data.std_values.as_ref().map(|s| s.get(i).copied().unwrap_or(0.0))
1023            })
1024        }).collect();
1025
1026        let spec = serde_json::json!({
1027            "$schema": "https://vega.github.io/schema/vega-lite/v5.json",
1028            "description": "Feature Importance Visualization",
1029            "width": config.width,
1030            "height": config.height,
1031            "data": {
1032                "values": vega_data
1033            },
1034            "mark": {
1035                "type": "bar",
1036                "tooltip": true
1037            },
1038            "encoding": {
1039                "x": {
1040                    "field": "feature",
1041                    "type": "nominal",
1042                    "axis": {
1043                        "title": "Features",
1044                        "labelAngle": -45
1045                    }
1046                },
1047                "y": {
1048                    "field": "importance",
1049                    "type": "quantitative",
1050                    "axis": {
1051                        "title": "Importance Score"
1052                    }
1053                },
1054                "color": {
1055                    "value": "steelblue"
1056                },
1057                "tooltip": [
1058                    {"field": "feature", "type": "nominal"},
1059                    {"field": "importance", "type": "quantitative", "format": ".4f"}
1060                ]
1061            },
1062            "title": {
1063                "text": "Feature Importance",
1064                "fontSize": 16,
1065                "anchor": "start"
1066            }
1067        });
1068
1069        let embed_options = serde_json::json!({
1070            "actions": self.config.default_config.actions,
1071            "tooltip": self.config.default_config.tooltip,
1072            "renderer": self.config.default_config.renderer,
1073            "theme": self.config.theme
1074        });
1075
1076        let js_code = format!(
1077            r#"
1078            const spec = {};
1079            const opt = {};
1080            vegaEmbed('#vega-div', spec, opt).catch(console.error);
1081            "#,
1082            serde_json::to_string_pretty(&spec)
1083                .map_err(|e| SklearsError::Other(format!("JSON error: {}", e)))?,
1084            serde_json::to_string_pretty(&embed_options)
1085                .map_err(|e| SklearsError::Other(format!("JSON error: {}", e)))?
1086        );
1087
1088        Ok(js_code)
1089    }
1090
1091    /// Generate complete HTML with Vega-Lite
1092    fn generate_complete_html(&self, js_code: &str, title: &str) -> String {
1093        format!(
1094            r#"<!DOCTYPE html>
1095<html>
1096<head>
1097    <meta charset="utf-8">
1098    <title>{}</title>
1099    <script src="{}"></script>
1100    <script src="{}"></script>
1101    <script src="{}"></script>
1102    <style>
1103        body {{
1104            font-family: Arial, sans-serif;
1105            margin: 20px;
1106        }}
1107        #vega-div {{
1108            width: 100%;
1109            height: 100%;
1110        }}
1111    </style>
1112</head>
1113<body>
1114    <h1>{}</h1>
1115    <div id="vega-div"></div>
1116    <script>
1117        {}
1118    </script>
1119</body>
1120</html>"#,
1121            title,
1122            self.config.vega_url,
1123            self.config.vega_lite_url,
1124            self.config.vega_embed_url,
1125            title,
1126            js_code
1127        )
1128    }
1129}
1130
1131impl VisualizationBackend for VegaLiteBackend {
1132    fn render_feature_importance(
1133        &self,
1134        data: &FeatureImportanceData,
1135        config: &BackendConfig,
1136    ) -> SklResult<RenderedVisualization> {
1137        let js_code = self.generate_feature_importance_vega(data, config)?;
1138
1139        let content = match config.format {
1140            OutputFormat::Html => self.generate_complete_html(&js_code, "Feature Importance"),
1141            OutputFormat::Json => serde_json::to_string_pretty(&serde_json::json!({
1142                "type": "vega-lite",
1143                "javascript": js_code,
1144                "title": "Feature Importance"
1145            }))
1146            .map_err(|e| SklearsError::Other(format!("JSON serialization error: {}", e)))?,
1147            _ => {
1148                return Err(SklearsError::InvalidInput(
1149                    "Unsupported format for Vega-Lite backend".to_string(),
1150                ))
1151            }
1152        };
1153
1154        Ok(RenderedVisualization {
1155            content,
1156            format: config.format,
1157            metadata: VisualizationMetadata {
1158                backend: "vega-lite".to_string(),
1159                render_time_ms: 0,
1160                file_size_bytes: 0, // Will be calculated after content is moved
1161                data_points: 1,     // Custom plot data count
1162                created_at: chrono::Utc::now(),
1163            },
1164            binary_data: None,
1165        })
1166    }
1167
1168    fn render_shap_plot(
1169        &self,
1170        _data: &ShapData,
1171        _config: &BackendConfig,
1172    ) -> SklResult<RenderedVisualization> {
1173        // Implementation would be similar to feature importance but with SHAP-specific styling
1174        Err(SklearsError::NotImplemented(
1175            "SHAP plot for Vega-Lite backend not yet implemented".to_string(),
1176        ))
1177    }
1178
1179    fn render_partial_dependence(
1180        &self,
1181        _data: &PartialDependenceData,
1182        _config: &BackendConfig,
1183    ) -> SklResult<RenderedVisualization> {
1184        // Implementation would create line plots for partial dependence
1185        Err(SklearsError::NotImplemented(
1186            "Partial dependence plot for Vega-Lite backend not yet implemented".to_string(),
1187        ))
1188    }
1189
1190    fn render_comparative_plot(
1191        &self,
1192        _data: &ComparativeData,
1193        _config: &BackendConfig,
1194    ) -> SklResult<RenderedVisualization> {
1195        // Implementation would create grouped bar charts
1196        Err(SklearsError::NotImplemented(
1197            "Comparative plot for Vega-Lite backend not yet implemented".to_string(),
1198        ))
1199    }
1200
1201    fn render_custom_plot(
1202        &self,
1203        data: &CustomPlotData,
1204        config: &BackendConfig,
1205    ) -> SklResult<RenderedVisualization> {
1206        // For custom plots, we expect Vega-Lite specification in the data
1207        let spec = data.data.get("vega_lite_spec").ok_or_else(|| {
1208            SklearsError::InvalidInput("Vega-Lite spec not found in custom plot data".to_string())
1209        })?;
1210
1211        let js_code = format!(
1212            r#"
1213            const spec = {};
1214            vegaEmbed('#vega-div', spec, {{}}).catch(console.error);
1215            "#,
1216            serde_json::to_string_pretty(spec)
1217                .map_err(|e| SklearsError::Other(format!("JSON error: {}", e)))?
1218        );
1219
1220        let content = match config.format {
1221            OutputFormat::Html => self.generate_complete_html(&js_code, "Custom Vega-Lite Plot"),
1222            OutputFormat::Json => serde_json::to_string_pretty(&serde_json::json!({
1223                "type": "vega-lite",
1224                "javascript": js_code,
1225                "title": "Custom Vega-Lite Plot"
1226            }))
1227            .map_err(|e| SklearsError::Other(format!("JSON serialization error: {}", e)))?,
1228            _ => {
1229                return Err(SklearsError::InvalidInput(
1230                    "Unsupported format for Vega-Lite backend".to_string(),
1231                ))
1232            }
1233        };
1234
1235        Ok(RenderedVisualization {
1236            content,
1237            format: config.format,
1238            metadata: VisualizationMetadata {
1239                backend: "vega-lite".to_string(),
1240                render_time_ms: 0,
1241                file_size_bytes: 0, // Will be calculated after content is moved
1242                data_points: 1,     // Custom plot data count
1243                created_at: chrono::Utc::now(),
1244            },
1245            binary_data: None,
1246        })
1247    }
1248
1249    fn name(&self) -> &str {
1250        "vega-lite"
1251    }
1252
1253    fn supported_formats(&self) -> Vec<OutputFormat> {
1254        vec![OutputFormat::Html, OutputFormat::Json]
1255    }
1256
1257    fn supports_interactivity(&self) -> bool {
1258        true
1259    }
1260
1261    fn capabilities(&self) -> BackendCapabilities {
1262        BackendCapabilities {
1263            formats: vec![OutputFormat::Html, OutputFormat::Json],
1264            interactive: true,
1265            animations: true,
1266            three_d: false,
1267            custom_themes: true,
1268            real_time_updates: false,
1269            max_data_points: Some(100000),
1270        }
1271    }
1272}
1273
#[cfg(test)]
mod tests {
    use super::*;
    use crate::visualization_backend::OutputFormat;

    /// Three features with descending importance and matching std values.
    fn create_test_feature_importance_data() -> FeatureImportanceData {
        FeatureImportanceData {
            feature_names: (0..3).map(|i| format!("Feature {}", i)).collect(),
            importance_values: vec![0.5, 0.3, 0.2],
            std_values: Some(vec![0.1, 0.05, 0.03]),
            plot_type: crate::visualization_backend::FeatureImportanceType::Bar,
            title: "Feature Importance".to_string(),
            x_label: "Features".to_string(),
            y_label: "Importance Score".to_string(),
        }
    }

    /// HTML output at 800x600; every other option keeps its default.
    fn create_test_config() -> BackendConfig {
        BackendConfig {
            format: OutputFormat::Html,
            width: 800,
            height: 600,
            ..Default::default()
        }
    }

    /// Render the shared fixture as HTML with `backend`.
    ///
    /// Checks the output format, that `marker` appears in the page, and that
    /// the metadata names `expected_backend`.
    fn check_html_feature_importance<B: VisualizationBackend>(
        backend: &B,
        marker: &str,
        expected_backend: &str,
    ) {
        let data = create_test_feature_importance_data();
        let config = create_test_config();
        let visualization = backend
            .render_feature_importance(&data, &config)
            .expect("feature importance rendering should succeed");

        assert_eq!(visualization.format, OutputFormat::Html);
        assert!(visualization.content.contains(marker));
        assert_eq!(visualization.metadata.backend, expected_backend);
    }

    #[test]
    fn test_plotly_backend_creation() {
        let backend = PlotlyBackend::new(PlotlyConfig::default());
        assert_eq!(backend.name(), "plotly");
        assert!(backend.supports_interactivity());
    }

    #[test]
    fn test_plotly_feature_importance_rendering() {
        let backend = PlotlyBackend::new(PlotlyConfig::default());
        check_html_feature_importance(&backend, "Plotly.newPlot", "plotly");
    }

    #[test]
    fn test_d3_backend_creation() {
        let backend = D3Backend::new(D3Config::default());
        assert_eq!(backend.name(), "d3");
        assert!(backend.supports_interactivity());
    }

    #[test]
    fn test_d3_feature_importance_rendering() {
        let backend = D3Backend::new(D3Config::default());
        check_html_feature_importance(&backend, "d3.select", "d3");
    }

    #[test]
    fn test_vega_lite_backend_creation() {
        let backend = VegaLiteBackend::new(VegaLiteConfig::default());
        assert_eq!(backend.name(), "vega-lite");
        assert!(backend.supports_interactivity());
    }

    #[test]
    fn test_vega_lite_feature_importance_rendering() {
        let backend = VegaLiteBackend::new(VegaLiteConfig::default());
        check_html_feature_importance(&backend, "vegaEmbed", "vega-lite");
    }

    #[test]
    fn test_backend_capabilities() {
        let all_caps = [
            PlotlyBackend::new(PlotlyConfig::default()).capabilities(),
            D3Backend::new(D3Config::default()).capabilities(),
            VegaLiteBackend::new(VegaLiteConfig::default()).capabilities(),
        ];

        for caps in &all_caps {
            assert!(caps.interactive);
            assert!(caps.animations);
        }
    }

    #[test]
    fn test_supported_formats() {
        let formats = PlotlyBackend::new(PlotlyConfig::default()).supported_formats();

        for expected in &[OutputFormat::Html, OutputFormat::Json] {
            assert!(formats.contains(expected));
        }
    }

    #[test]
    fn test_json_output_format() {
        let backend = PlotlyBackend::new(PlotlyConfig::default());
        let data = create_test_feature_importance_data();
        let config = BackendConfig {
            format: OutputFormat::Json,
            ..create_test_config()
        };

        let visualization = backend
            .render_feature_importance(&data, &config)
            .expect("JSON rendering should succeed");
        assert_eq!(visualization.format, OutputFormat::Json);

        // The content must parse as JSON and carry the envelope keys.
        let json_value: serde_json::Value =
            serde_json::from_str(&visualization.content).expect("content should be valid JSON");
        for key in &["type", "javascript"] {
            assert!(json_value.get(*key).is_some());
        }
    }
}