1use crate::{
7 visualization_backend::{
8 BackendCapabilities, BackendConfig, ComparativeData, CustomPlotData, FeatureImportanceData,
9 OutputFormat, PartialDependenceData, RenderedVisualization, ShapData, VisualizationBackend,
10 VisualizationMetadata,
11 },
12 Float, SklResult, SklearsError,
13};
14#[cfg(feature = "serde")]
15use serde::{Deserialize, Serialize};
16
17#[derive(Debug)]
19pub struct PlotlyBackend {
20 config: PlotlyConfig,
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct PlotlyConfig {
26 pub version: String,
28 pub cdn_url: String,
30 pub default_config: PlotlyPlotConfig,
32 pub responsive: bool,
34 pub custom_js: Option<String>,
36}
37
/// Subset of Plotly's per-plot `config` options controlled by the backend.
///
/// The serde derives are applied only when the `serde` feature is enabled,
/// matching the feature-gated `serde` import of this module.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct PlotlyPlotConfig {
    /// Plotly `displayModeBar`: show the toolbar above the plot.
    pub display_mode_bar: bool,
    /// Plotly `staticPlot`: disable all interaction.
    pub static_plot: bool,
    /// Plotly `showTips`.
    pub show_tips: bool,
    /// Plotly `editable`: allow editing titles/annotations in the browser.
    pub editable: bool,
}
49
50impl Default for PlotlyConfig {
51 fn default() -> Self {
52 Self {
53 version: "2.27.0".to_string(),
54 cdn_url: "https://cdn.plot.ly/plotly-2.27.0.min.js".to_string(),
55 default_config: PlotlyPlotConfig {
56 display_mode_bar: true,
57 static_plot: false,
58 show_tips: true,
59 editable: false,
60 },
61 responsive: true,
62 custom_js: None,
63 }
64 }
65}
66
67impl PlotlyBackend {
68 pub fn new(config: PlotlyConfig) -> Self {
70 Self { config }
71 }
72
73 fn generate_feature_importance_plotly(
75 &self,
76 data: &FeatureImportanceData,
77 config: &BackendConfig,
78 ) -> SklResult<String> {
79 let mut x_values = Vec::new();
80 let mut y_values = Vec::new();
81 let mut error_y = Vec::new();
82
83 for (i, &importance) in data.importance_values.iter().enumerate() {
84 x_values.push(
85 data.feature_names
86 .get(i)
87 .cloned()
88 .unwrap_or_else(|| format!("Feature {}", i)),
89 );
90 y_values.push(importance);
91 if let Some(ref std_values) = data.std_values {
92 error_y.push(std_values.get(i).copied().unwrap_or(0.0));
93 } else {
94 error_y.push(0.0);
95 }
96 }
97
98 let has_error = data.std_values.is_some() && !error_y.iter().all(|&x| x == 0.0);
99
100 let plot_data = if has_error {
101 format!(
102 r#"{{
103 x: {:?},
104 y: {:?},
105 error_y: {{
106 type: 'data',
107 array: {:?},
108 visible: true
109 }},
110 type: 'bar',
111 name: 'Feature Importance',
112 marker: {{
113 color: 'rgba(158,202,225,0.8)',
114 line: {{
115 color: 'rgba(8,48,107,1.0)',
116 width: 1.5
117 }}
118 }}
119 }}"#,
120 x_values, y_values, error_y
121 )
122 } else {
123 format!(
124 r#"{{
125 x: {:?},
126 y: {:?},
127 type: 'bar',
128 name: 'Feature Importance',
129 marker: {{
130 color: 'rgba(158,202,225,0.8)',
131 line: {{
132 color: 'rgba(8,48,107,1.0)',
133 width: 1.5
134 }}
135 }}
136 }}"#,
137 x_values, y_values
138 )
139 };
140
141 let layout = format!(
142 r#"{{
143 title: 'Feature Importance',
144 xaxis: {{
145 title: 'Features',
146 tickangle: -45
147 }},
148 yaxis: {{
149 title: 'Importance Score'
150 }},
151 width: {},
152 height: {},
153 margin: {{
154 l: 60,
155 r: 30,
156 b: 120,
157 t: 60
158 }}
159 }}"#,
160 config.width, config.height
161 );
162
163 let plot_config = format!(
164 r#"{{
165 displayModeBar: {},
166 staticPlot: {},
167 showTips: {},
168 editable: {},
169 responsive: {}
170 }}"#,
171 self.config.default_config.display_mode_bar,
172 self.config.default_config.static_plot,
173 self.config.default_config.show_tips,
174 self.config.default_config.editable,
175 self.config.responsive
176 );
177
178 Ok(format!(
179 r#"
180 var data = [{}];
181 var layout = {};
182 var config = {};
183 Plotly.newPlot('plotly-div', data, layout, config);
184 "#,
185 plot_data, layout, plot_config
186 ))
187 }
188
189 fn generate_shap_plotly(&self, data: &ShapData, config: &BackendConfig) -> SklResult<String> {
191 let x_values: Vec<String> = data.feature_names.clone();
192 let shap_values: Vec<Float> = if data.shap_values.nrows() > 0 {
194 data.shap_values.row(0).to_vec()
195 } else {
196 vec![0.0; data.feature_names.len()]
197 };
198
199 let plot_data = format!(
200 r#"{{
201 x: {:?},
202 y: {:?},
203 type: 'bar',
204 name: 'SHAP Values',
205 marker: {{
206 color: {:?},
207 colorscale: 'RdBu',
208 line: {{
209 color: 'rgba(0,0,0,0.2)',
210 width: 1
211 }}
212 }}
213 }}"#,
214 x_values,
215 shap_values,
216 shap_values
217 .iter()
218 .map(|&v| if v >= 0.0 {
219 "rgba(255,0,0,0.8)"
220 } else {
221 "rgba(0,0,255,0.8)"
222 })
223 .collect::<Vec<_>>()
224 );
225
226 let layout = format!(
227 r#"{{
228 title: 'SHAP Values',
229 xaxis: {{
230 title: 'Features',
231 tickangle: -45
232 }},
233 yaxis: {{
234 title: 'SHAP Value',
235 zeroline: true,
236 zerolinecolor: 'rgb(0,0,0)',
237 zerolinewidth: 2
238 }},
239 width: {},
240 height: {},
241 margin: {{
242 l: 60,
243 r: 30,
244 b: 120,
245 t: 60
246 }}
247 }}"#,
248 config.width, config.height
249 );
250
251 let plot_config = format!(
252 r#"{{
253 displayModeBar: {},
254 staticPlot: {},
255 responsive: {}
256 }}"#,
257 self.config.default_config.display_mode_bar,
258 self.config.default_config.static_plot,
259 self.config.responsive
260 );
261
262 Ok(format!(
263 r#"
264 var data = [{}];
265 var layout = {};
266 var config = {};
267 Plotly.newPlot('plotly-div', data, layout, config);
268 "#,
269 plot_data, layout, plot_config
270 ))
271 }
272
273 fn generate_complete_html(&self, js_code: &str, title: &str) -> String {
275 format!(
276 r#"<!DOCTYPE html>
277<html>
278<head>
279 <meta charset="utf-8">
280 <title>{}</title>
281 <script src="{}"></script>
282 <style>
283 body {{
284 font-family: Arial, sans-serif;
285 margin: 20px;
286 }}
287 #plotly-div {{
288 width: 100%;
289 height: 100%;
290 }}
291 </style>
292</head>
293<body>
294 <h1>{}</h1>
295 <div id="plotly-div"></div>
296 <script>
297 {}
298 {}
299 </script>
300</body>
301</html>"#,
302 title,
303 self.config.cdn_url,
304 title,
305 js_code,
306 self.config.custom_js.as_deref().unwrap_or("")
307 )
308 }
309}
310
311impl VisualizationBackend for PlotlyBackend {
312 fn render_feature_importance(
313 &self,
314 data: &FeatureImportanceData,
315 config: &BackendConfig,
316 ) -> SklResult<RenderedVisualization> {
317 let js_code = self.generate_feature_importance_plotly(data, config)?;
318
319 let content = match config.format {
320 OutputFormat::Html => self.generate_complete_html(&js_code, "Feature Importance"),
321 OutputFormat::Json => serde_json::to_string_pretty(&serde_json::json!({
322 "type": "plotly",
323 "javascript": js_code,
324 "title": "Feature Importance"
325 }))
326 .map_err(|e| SklearsError::Other(format!("JSON serialization error: {}", e)))?,
327 _ => {
328 return Err(SklearsError::InvalidInput(
329 "Unsupported format for Plotly backend".to_string(),
330 ))
331 }
332 };
333
334 Ok(RenderedVisualization {
335 content,
336 format: config.format,
337 metadata: VisualizationMetadata {
338 backend: "plotly".to_string(),
339 render_time_ms: 0,
340 file_size_bytes: 0, data_points: 1, created_at: chrono::Utc::now(),
343 },
344 binary_data: None,
345 })
346 }
347
348 fn render_shap_plot(
349 &self,
350 data: &ShapData,
351 config: &BackendConfig,
352 ) -> SklResult<RenderedVisualization> {
353 let js_code = self.generate_shap_plotly(data, config)?;
354
355 let content = match config.format {
356 OutputFormat::Html => self.generate_complete_html(&js_code, "SHAP Values"),
357 OutputFormat::Json => serde_json::to_string_pretty(&serde_json::json!({
358 "type": "plotly",
359 "javascript": js_code,
360 "title": "SHAP Values"
361 }))
362 .map_err(|e| SklearsError::Other(format!("JSON serialization error: {}", e)))?,
363 _ => {
364 return Err(SklearsError::InvalidInput(
365 "Unsupported format for Plotly backend".to_string(),
366 ))
367 }
368 };
369
370 Ok(RenderedVisualization {
371 content,
372 format: config.format,
373 metadata: VisualizationMetadata {
374 backend: "plotly".to_string(),
375 render_time_ms: 0,
376 file_size_bytes: 0, data_points: 1, created_at: chrono::Utc::now(),
379 },
380 binary_data: None,
381 })
382 }
383
384 fn render_partial_dependence(
385 &self,
386 data: &PartialDependenceData,
387 config: &BackendConfig,
388 ) -> SklResult<RenderedVisualization> {
389 let plot_data = format!(
390 r#"{{
391 x: {:?},
392 y: {:?},
393 type: 'scatter',
394 mode: 'lines+markers',
395 name: 'Partial Dependence',
396 line: {{
397 color: 'rgb(31, 119, 180)',
398 width: 3
399 }},
400 marker: {{
401 color: 'rgb(31, 119, 180)',
402 size: 6
403 }}
404 }}"#,
405 data.feature_values.to_vec(),
406 data.pd_values.to_vec()
407 );
408
409 let layout = format!(
410 r#"{{
411 title: 'Partial Dependence Plot',
412 xaxis: {{
413 title: 'Feature Value'
414 }},
415 yaxis: {{
416 title: 'Partial Dependence'
417 }},
418 width: {},
419 height: {}
420 }}"#,
421 config.width, config.height
422 );
423
424 let js_code = format!(
425 r#"
426 var data = [{}];
427 var layout = {};
428 Plotly.newPlot('plotly-div', data, layout);
429 "#,
430 plot_data, layout
431 );
432
433 let content = match config.format {
434 OutputFormat::Html => self.generate_complete_html(&js_code, "Partial Dependence Plot"),
435 OutputFormat::Json => serde_json::to_string_pretty(&serde_json::json!({
436 "type": "plotly",
437 "javascript": js_code,
438 "title": "Partial Dependence Plot"
439 }))
440 .map_err(|e| SklearsError::Other(format!("JSON serialization error: {}", e)))?,
441 _ => {
442 return Err(SklearsError::InvalidInput(
443 "Unsupported format for Plotly backend".to_string(),
444 ))
445 }
446 };
447
448 Ok(RenderedVisualization {
449 content,
450 format: config.format,
451 metadata: VisualizationMetadata {
452 backend: "plotly".to_string(),
453 render_time_ms: 0,
454 file_size_bytes: 0, data_points: data.feature_values.len(),
456 created_at: chrono::Utc::now(),
457 },
458 binary_data: None,
459 })
460 }
461
462 fn render_comparative_plot(
463 &self,
464 data: &ComparativeData,
465 config: &BackendConfig,
466 ) -> SklResult<RenderedVisualization> {
467 let mut traces = Vec::new();
468
469 for (i, (method_name, method_data)) in data.model_data.iter().enumerate() {
470 let color = match i % 6 {
471 0 => "rgb(31, 119, 180)",
472 1 => "rgb(255, 127, 14)",
473 2 => "rgb(44, 160, 44)",
474 3 => "rgb(214, 39, 40)",
475 4 => "rgb(148, 103, 189)",
476 _ => "rgb(140, 86, 75)",
477 };
478
479 let values: Vec<Float> = if method_data.nrows() > 0 {
481 method_data.row(0).to_vec()
482 } else {
483 vec![0.0; data.labels.len()]
484 };
485
486 traces.push(format!(
487 r#"{{
488 x: {:?},
489 y: {:?},
490 type: 'bar',
491 name: '{}',
492 marker: {{ color: '{}' }}
493 }}"#,
494 data.labels, values, method_name, color
495 ));
496 }
497
498 let layout = format!(
499 r#"{{
500 title: 'Method Comparison',
501 xaxis: {{
502 title: 'Features',
503 tickangle: -45
504 }},
505 yaxis: {{
506 title: 'Importance Score'
507 }},
508 barmode: 'group',
509 width: {},
510 height: {}
511 }}"#,
512 config.width, config.height
513 );
514
515 let js_code = format!(
516 r#"
517 var data = [{}];
518 var layout = {};
519 Plotly.newPlot('plotly-div', data, layout);
520 "#,
521 traces.join(","),
522 layout
523 );
524
525 let content = match config.format {
526 OutputFormat::Html => self.generate_complete_html(&js_code, "Method Comparison"),
527 OutputFormat::Json => serde_json::to_string_pretty(&serde_json::json!({
528 "type": "plotly",
529 "javascript": js_code,
530 "title": "Method Comparison"
531 }))
532 .map_err(|e| SklearsError::Other(format!("JSON serialization error: {}", e)))?,
533 _ => {
534 return Err(SklearsError::InvalidInput(
535 "Unsupported format for Plotly backend".to_string(),
536 ))
537 }
538 };
539
540 Ok(RenderedVisualization {
541 content,
542 format: config.format,
543 metadata: VisualizationMetadata {
544 backend: "plotly".to_string(),
545 render_time_ms: 0,
546 file_size_bytes: 0, data_points: data.labels.len(),
548 created_at: chrono::Utc::now(),
549 },
550 binary_data: None,
551 })
552 }
553
554 fn render_custom_plot(
555 &self,
556 data: &CustomPlotData,
557 config: &BackendConfig,
558 ) -> SklResult<RenderedVisualization> {
559 let js_code = format!(
561 r#"
562 var data = {};
563 var layout = {};
564 Plotly.newPlot('plotly-div', data, layout);
565 "#,
566 data.data
567 .get("data")
568 .unwrap_or(&serde_json::Value::Array(vec![])),
569 data.data.get("layout").unwrap_or(&serde_json::json!({}))
570 );
571
572 let content = match config.format {
573 OutputFormat::Html => self.generate_complete_html(&js_code, "Custom Plot"),
574 OutputFormat::Json => serde_json::to_string_pretty(&serde_json::json!({
575 "type": "plotly",
576 "javascript": js_code,
577 "title": "Custom Plot"
578 }))
579 .map_err(|e| SklearsError::Other(format!("JSON serialization error: {}", e)))?,
580 _ => {
581 return Err(SklearsError::InvalidInput(
582 "Unsupported format for Plotly backend".to_string(),
583 ))
584 }
585 };
586
587 Ok(RenderedVisualization {
588 content,
589 format: config.format,
590 metadata: VisualizationMetadata {
591 backend: "plotly".to_string(),
592 render_time_ms: 0,
593 file_size_bytes: 0, data_points: 1, created_at: chrono::Utc::now(),
596 },
597 binary_data: None,
598 })
599 }
600
601 fn name(&self) -> &str {
602 "plotly"
603 }
604
605 fn supported_formats(&self) -> Vec<OutputFormat> {
606 vec![OutputFormat::Html, OutputFormat::Json]
607 }
608
609 fn supports_interactivity(&self) -> bool {
610 !self.config.default_config.static_plot
611 }
612
613 fn capabilities(&self) -> BackendCapabilities {
614 BackendCapabilities {
615 formats: self.supported_formats(),
616 interactive: self.supports_interactivity(),
617 animations: true,
618 three_d: false,
619 custom_themes: true,
620 real_time_updates: true,
621 max_data_points: Some(100000),
622 }
623 }
624}
625
626#[derive(Debug)]
628pub struct D3Backend {
629 config: D3Config,
630}
631
/// Configuration for [`D3Backend`].
///
/// The serde derives are applied only when the `serde` feature is enabled,
/// matching the feature-gated `serde` import of this module.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct D3Config {
    /// D3 version string; not used when generating output.
    pub version: String,
    /// URL of the D3 bundle loaded by the generated HTML page.
    pub cdn_url: String,
    /// Extra CSS appended inside the page's `<style>` element.
    pub custom_css: Option<String>,
    /// Extra JavaScript appended after the generated plot code.
    pub custom_js: Option<String>,
}
643
644impl Default for D3Config {
645 fn default() -> Self {
646 Self {
647 version: "7.8.5".to_string(),
648 cdn_url: "https://cdn.jsdelivr.net/npm/d3@7".to_string(),
649 custom_css: None,
650 custom_js: None,
651 }
652 }
653}
654
655impl D3Backend {
656 pub fn new(config: D3Config) -> Self {
658 Self { config }
659 }
660
661 fn generate_feature_importance_d3(
663 &self,
664 data: &FeatureImportanceData,
665 config: &BackendConfig,
666 ) -> SklResult<String> {
667 let data_json = serde_json::to_string(&serde_json::json!({
668 "features": data.importance_values.iter().enumerate().map(|(i, &val)| {
669 serde_json::json!({
670 "name": data.feature_names.get(i).cloned().unwrap_or_else(|| format!("Feature {}", i)),
671 "value": val,
672 "std": data.std_values.as_ref().map(|s| s.get(i).copied().unwrap_or(0.0))
673 })
674 }).collect::<Vec<_>>()
675 })).map_err(|e| SklearsError::Other(format!("JSON serialization error: {}", e)))?;
676
677 Ok(format!(
678 r##"
679 const data = {};
680
681 const margin = {{top: 20, right: 30, bottom: 70, left: 60}};
682 const width = {} - margin.left - margin.right;
683 const height = {} - margin.bottom - margin.top;
684
685 const svg = d3.select("#d3-div")
686 .append("svg")
687 .attr("width", width + margin.left + margin.right)
688 .attr("height", height + margin.top + margin.bottom);
689
690 const g = svg.append("g")
691 .attr("transform", `translate(${{margin.left}},${{margin.top}})`);
692
693 const x = d3.scaleBand()
694 .range([0, width])
695 .domain(data.features.map(d => d.name))
696 .padding(0.1);
697
698 const y = d3.scaleLinear()
699 .range([height, 0])
700 .domain([0, d3.max(data.features, d => d.value)]);
701
702 g.append("g")
703 .attr("transform", `translate(0,${{height}})`)
704 .call(d3.axisBottom(x))
705 .selectAll("text")
706 .style("text-anchor", "end")
707 .attr("dx", "-.8em")
708 .attr("dy", ".15em")
709 .attr("transform", "rotate(-45)");
710
711 g.append("g")
712 .call(d3.axisLeft(y));
713
714 g.selectAll(".bar")
715 .data(data.features)
716 .enter().append("rect")
717 .attr("class", "bar")
718 .attr("x", d => x(d.name))
719 .attr("width", x.bandwidth())
720 .attr("y", d => y(d.value))
721 .attr("height", d => height - y(d.value))
722 .attr("fill", "steelblue")
723 .on("mouseover", function(event, d) {{
724 d3.select(this).attr("fill", "orange");
725
726 const tooltip = d3.select("body").append("div")
727 .attr("class", "tooltip")
728 .style("opacity", 0)
729 .style("position", "absolute")
730 .style("background", "rgba(0,0,0,0.8)")
731 .style("color", "white")
732 .style("padding", "10px")
733 .style("border-radius", "5px")
734 .style("pointer-events", "none");
735
736 tooltip.transition()
737 .duration(200)
738 .style("opacity", .9);
739
740 tooltip.html(`${{d.name}}: ${{d.value.toFixed(4)}}`)
741 .style("left", (event.pageX + 10) + "px")
742 .style("top", (event.pageY - 28) + "px");
743 }})
744 .on("mouseout", function(d) {{
745 d3.select(this).attr("fill", "steelblue");
746 d3.selectAll(".tooltip").remove();
747 }});
748
749 // Add title
750 g.append("text")
751 .attr("x", width / 2)
752 .attr("y", 0 - (margin.top / 2))
753 .attr("text-anchor", "middle")
754 .style("font-size", "16px")
755 .style("font-weight", "bold")
756 .text("Feature Importance");
757
758 // Add axis labels
759 g.append("text")
760 .attr("transform", "rotate(-90)")
761 .attr("y", 0 - margin.left)
762 .attr("x", 0 - (height / 2))
763 .attr("dy", "1em")
764 .style("text-anchor", "middle")
765 .text("Importance Score");
766 "##,
767 data_json, config.width, config.height
768 ))
769 }
770
771 fn generate_complete_html(&self, js_code: &str, title: &str) -> String {
773 format!(
774 r#"<!DOCTYPE html>
775<html>
776<head>
777 <meta charset="utf-8">
778 <title>{}</title>
779 <script src="{}"></script>
780 <style>
781 body {{
782 font-family: Arial, sans-serif;
783 margin: 20px;
784 }}
785 #d3-div {{
786 width: 100%;
787 height: 100%;
788 }}
789 .bar:hover {{
790 fill: orange;
791 }}
792 {}
793 </style>
794</head>
795<body>
796 <h1>{}</h1>
797 <div id="d3-div"></div>
798 <script>
799 {}
800 {}
801 </script>
802</body>
803</html>"#,
804 title,
805 self.config.cdn_url,
806 self.config.custom_css.as_deref().unwrap_or(""),
807 title,
808 js_code,
809 self.config.custom_js.as_deref().unwrap_or("")
810 )
811 }
812}
813
814impl VisualizationBackend for D3Backend {
815 fn render_feature_importance(
816 &self,
817 data: &FeatureImportanceData,
818 config: &BackendConfig,
819 ) -> SklResult<RenderedVisualization> {
820 let js_code = self.generate_feature_importance_d3(data, config)?;
821
822 let content = match config.format {
823 OutputFormat::Html => self.generate_complete_html(&js_code, "Feature Importance"),
824 OutputFormat::Json => serde_json::to_string_pretty(&serde_json::json!({
825 "type": "d3",
826 "javascript": js_code,
827 "title": "Feature Importance"
828 }))
829 .map_err(|e| SklearsError::Other(format!("JSON serialization error: {}", e)))?,
830 _ => {
831 return Err(SklearsError::InvalidInput(
832 "Unsupported format for D3 backend".to_string(),
833 ))
834 }
835 };
836
837 Ok(RenderedVisualization {
838 content,
839 format: config.format,
840 metadata: VisualizationMetadata {
841 backend: "d3".to_string(),
842 render_time_ms: 0,
843 file_size_bytes: 0, data_points: 1, created_at: chrono::Utc::now(),
846 },
847 binary_data: None,
848 })
849 }
850
851 fn render_shap_plot(
852 &self,
853 _data: &ShapData,
854 _config: &BackendConfig,
855 ) -> SklResult<RenderedVisualization> {
856 Err(SklearsError::NotImplemented(
858 "SHAP plot for D3 backend not yet implemented".to_string(),
859 ))
860 }
861
862 fn render_partial_dependence(
863 &self,
864 _data: &PartialDependenceData,
865 _config: &BackendConfig,
866 ) -> SklResult<RenderedVisualization> {
867 Err(SklearsError::NotImplemented(
869 "Partial dependence plot for D3 backend not yet implemented".to_string(),
870 ))
871 }
872
873 fn render_comparative_plot(
874 &self,
875 _data: &ComparativeData,
876 _config: &BackendConfig,
877 ) -> SklResult<RenderedVisualization> {
878 Err(SklearsError::NotImplemented(
880 "Comparative plot for D3 backend not yet implemented".to_string(),
881 ))
882 }
883
884 fn render_custom_plot(
885 &self,
886 data: &CustomPlotData,
887 config: &BackendConfig,
888 ) -> SklResult<RenderedVisualization> {
889 let js_code = data
891 .data
892 .get("d3_code")
893 .and_then(|v| v.as_str())
894 .ok_or_else(|| {
895 SklearsError::InvalidInput("D3 code not found in custom plot data".to_string())
896 })?;
897
898 let content = match config.format {
899 OutputFormat::Html => self.generate_complete_html(js_code, "Custom D3 Plot"),
900 OutputFormat::Json => serde_json::to_string_pretty(&serde_json::json!({
901 "type": "d3",
902 "javascript": js_code,
903 "title": "Custom D3 Plot"
904 }))
905 .map_err(|e| SklearsError::Other(format!("JSON serialization error: {}", e)))?,
906 _ => {
907 return Err(SklearsError::InvalidInput(
908 "Unsupported format for D3 backend".to_string(),
909 ))
910 }
911 };
912
913 Ok(RenderedVisualization {
914 content,
915 format: config.format,
916 metadata: VisualizationMetadata {
917 backend: "d3".to_string(),
918 render_time_ms: 0,
919 file_size_bytes: 0, data_points: 1, created_at: chrono::Utc::now(),
922 },
923 binary_data: None,
924 })
925 }
926
927 fn name(&self) -> &str {
928 "d3"
929 }
930
931 fn supported_formats(&self) -> Vec<OutputFormat> {
932 vec![OutputFormat::Html, OutputFormat::Json]
933 }
934
935 fn supports_interactivity(&self) -> bool {
936 true
937 }
938
939 fn capabilities(&self) -> BackendCapabilities {
940 BackendCapabilities {
941 formats: self.supported_formats(),
942 interactive: true,
943 animations: true,
944 three_d: false,
945 custom_themes: true,
946 real_time_updates: true,
947 max_data_points: Some(50000),
948 }
949 }
950}
951
952#[derive(Debug)]
954pub struct VegaLiteBackend {
955 config: VegaLiteConfig,
956}
957
958#[derive(Debug, Clone, Serialize, Deserialize)]
959pub struct VegaLiteConfig {
960 pub version: String,
962 pub vega_url: String,
964 pub vega_lite_url: String,
966 pub vega_embed_url: String,
968 pub theme: String,
970 pub default_config: VegaLiteDefaultConfig,
972}
973
/// Options forwarded to `vegaEmbed`.
///
/// The serde derives are applied only when the `serde` feature is enabled,
/// matching the feature-gated `serde` import of this module.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct VegaLiteDefaultConfig {
    /// Names of the embed action-menu entries to enable; presumably drawn
    /// from `export`, `source`, `compiled` and `editor` (the defaults).
    pub actions: Vec<String>,
    /// Whether `vegaEmbed` tooltips are enabled.
    pub tooltip: bool,
    /// Renderer name passed to `vegaEmbed` (default `"canvas"`).
    pub renderer: String,
}
983
984impl Default for VegaLiteConfig {
985 fn default() -> Self {
986 Self {
987 version: "5.8.0".to_string(),
988 vega_url: "https://cdn.jsdelivr.net/npm/vega@5".to_string(),
989 vega_lite_url: "https://cdn.jsdelivr.net/npm/vega-lite@5".to_string(),
990 vega_embed_url: "https://cdn.jsdelivr.net/npm/vega-embed@6".to_string(),
991 theme: "default".to_string(),
992 default_config: VegaLiteDefaultConfig {
993 actions: vec![
994 "export".to_string(),
995 "source".to_string(),
996 "compiled".to_string(),
997 "editor".to_string(),
998 ],
999 tooltip: true,
1000 renderer: "canvas".to_string(),
1001 },
1002 }
1003 }
1004}
1005
1006impl VegaLiteBackend {
1007 pub fn new(config: VegaLiteConfig) -> Self {
1009 Self { config }
1010 }
1011
1012 fn generate_feature_importance_vega(
1014 &self,
1015 data: &FeatureImportanceData,
1016 config: &BackendConfig,
1017 ) -> SklResult<String> {
1018 let vega_data: Vec<serde_json::Value> = data.importance_values.iter().enumerate().map(|(i, &val)| {
1019 serde_json::json!({
1020 "feature": data.feature_names.get(i).cloned().unwrap_or_else(|| format!("Feature {}", i)),
1021 "importance": val,
1022 "std": data.std_values.as_ref().map(|s| s.get(i).copied().unwrap_or(0.0))
1023 })
1024 }).collect();
1025
1026 let spec = serde_json::json!({
1027 "$schema": "https://vega.github.io/schema/vega-lite/v5.json",
1028 "description": "Feature Importance Visualization",
1029 "width": config.width,
1030 "height": config.height,
1031 "data": {
1032 "values": vega_data
1033 },
1034 "mark": {
1035 "type": "bar",
1036 "tooltip": true
1037 },
1038 "encoding": {
1039 "x": {
1040 "field": "feature",
1041 "type": "nominal",
1042 "axis": {
1043 "title": "Features",
1044 "labelAngle": -45
1045 }
1046 },
1047 "y": {
1048 "field": "importance",
1049 "type": "quantitative",
1050 "axis": {
1051 "title": "Importance Score"
1052 }
1053 },
1054 "color": {
1055 "value": "steelblue"
1056 },
1057 "tooltip": [
1058 {"field": "feature", "type": "nominal"},
1059 {"field": "importance", "type": "quantitative", "format": ".4f"}
1060 ]
1061 },
1062 "title": {
1063 "text": "Feature Importance",
1064 "fontSize": 16,
1065 "anchor": "start"
1066 }
1067 });
1068
1069 let embed_options = serde_json::json!({
1070 "actions": self.config.default_config.actions,
1071 "tooltip": self.config.default_config.tooltip,
1072 "renderer": self.config.default_config.renderer,
1073 "theme": self.config.theme
1074 });
1075
1076 let js_code = format!(
1077 r#"
1078 const spec = {};
1079 const opt = {};
1080 vegaEmbed('#vega-div', spec, opt).catch(console.error);
1081 "#,
1082 serde_json::to_string_pretty(&spec)
1083 .map_err(|e| SklearsError::Other(format!("JSON error: {}", e)))?,
1084 serde_json::to_string_pretty(&embed_options)
1085 .map_err(|e| SklearsError::Other(format!("JSON error: {}", e)))?
1086 );
1087
1088 Ok(js_code)
1089 }
1090
1091 fn generate_complete_html(&self, js_code: &str, title: &str) -> String {
1093 format!(
1094 r#"<!DOCTYPE html>
1095<html>
1096<head>
1097 <meta charset="utf-8">
1098 <title>{}</title>
1099 <script src="{}"></script>
1100 <script src="{}"></script>
1101 <script src="{}"></script>
1102 <style>
1103 body {{
1104 font-family: Arial, sans-serif;
1105 margin: 20px;
1106 }}
1107 #vega-div {{
1108 width: 100%;
1109 height: 100%;
1110 }}
1111 </style>
1112</head>
1113<body>
1114 <h1>{}</h1>
1115 <div id="vega-div"></div>
1116 <script>
1117 {}
1118 </script>
1119</body>
1120</html>"#,
1121 title,
1122 self.config.vega_url,
1123 self.config.vega_lite_url,
1124 self.config.vega_embed_url,
1125 title,
1126 js_code
1127 )
1128 }
1129}
1130
1131impl VisualizationBackend for VegaLiteBackend {
1132 fn render_feature_importance(
1133 &self,
1134 data: &FeatureImportanceData,
1135 config: &BackendConfig,
1136 ) -> SklResult<RenderedVisualization> {
1137 let js_code = self.generate_feature_importance_vega(data, config)?;
1138
1139 let content = match config.format {
1140 OutputFormat::Html => self.generate_complete_html(&js_code, "Feature Importance"),
1141 OutputFormat::Json => serde_json::to_string_pretty(&serde_json::json!({
1142 "type": "vega-lite",
1143 "javascript": js_code,
1144 "title": "Feature Importance"
1145 }))
1146 .map_err(|e| SklearsError::Other(format!("JSON serialization error: {}", e)))?,
1147 _ => {
1148 return Err(SklearsError::InvalidInput(
1149 "Unsupported format for Vega-Lite backend".to_string(),
1150 ))
1151 }
1152 };
1153
1154 Ok(RenderedVisualization {
1155 content,
1156 format: config.format,
1157 metadata: VisualizationMetadata {
1158 backend: "vega-lite".to_string(),
1159 render_time_ms: 0,
1160 file_size_bytes: 0, data_points: 1, created_at: chrono::Utc::now(),
1163 },
1164 binary_data: None,
1165 })
1166 }
1167
1168 fn render_shap_plot(
1169 &self,
1170 _data: &ShapData,
1171 _config: &BackendConfig,
1172 ) -> SklResult<RenderedVisualization> {
1173 Err(SklearsError::NotImplemented(
1175 "SHAP plot for Vega-Lite backend not yet implemented".to_string(),
1176 ))
1177 }
1178
1179 fn render_partial_dependence(
1180 &self,
1181 _data: &PartialDependenceData,
1182 _config: &BackendConfig,
1183 ) -> SklResult<RenderedVisualization> {
1184 Err(SklearsError::NotImplemented(
1186 "Partial dependence plot for Vega-Lite backend not yet implemented".to_string(),
1187 ))
1188 }
1189
1190 fn render_comparative_plot(
1191 &self,
1192 _data: &ComparativeData,
1193 _config: &BackendConfig,
1194 ) -> SklResult<RenderedVisualization> {
1195 Err(SklearsError::NotImplemented(
1197 "Comparative plot for Vega-Lite backend not yet implemented".to_string(),
1198 ))
1199 }
1200
1201 fn render_custom_plot(
1202 &self,
1203 data: &CustomPlotData,
1204 config: &BackendConfig,
1205 ) -> SklResult<RenderedVisualization> {
1206 let spec = data.data.get("vega_lite_spec").ok_or_else(|| {
1208 SklearsError::InvalidInput("Vega-Lite spec not found in custom plot data".to_string())
1209 })?;
1210
1211 let js_code = format!(
1212 r#"
1213 const spec = {};
1214 vegaEmbed('#vega-div', spec, {{}}).catch(console.error);
1215 "#,
1216 serde_json::to_string_pretty(spec)
1217 .map_err(|e| SklearsError::Other(format!("JSON error: {}", e)))?
1218 );
1219
1220 let content = match config.format {
1221 OutputFormat::Html => self.generate_complete_html(&js_code, "Custom Vega-Lite Plot"),
1222 OutputFormat::Json => serde_json::to_string_pretty(&serde_json::json!({
1223 "type": "vega-lite",
1224 "javascript": js_code,
1225 "title": "Custom Vega-Lite Plot"
1226 }))
1227 .map_err(|e| SklearsError::Other(format!("JSON serialization error: {}", e)))?,
1228 _ => {
1229 return Err(SklearsError::InvalidInput(
1230 "Unsupported format for Vega-Lite backend".to_string(),
1231 ))
1232 }
1233 };
1234
1235 Ok(RenderedVisualization {
1236 content,
1237 format: config.format,
1238 metadata: VisualizationMetadata {
1239 backend: "vega-lite".to_string(),
1240 render_time_ms: 0,
1241 file_size_bytes: 0, data_points: 1, created_at: chrono::Utc::now(),
1244 },
1245 binary_data: None,
1246 })
1247 }
1248
1249 fn name(&self) -> &str {
1250 "vega-lite"
1251 }
1252
1253 fn supported_formats(&self) -> Vec<OutputFormat> {
1254 vec![OutputFormat::Html, OutputFormat::Json]
1255 }
1256
1257 fn supports_interactivity(&self) -> bool {
1258 true
1259 }
1260
1261 fn capabilities(&self) -> BackendCapabilities {
1262 BackendCapabilities {
1263 formats: vec![OutputFormat::Html, OutputFormat::Json],
1264 interactive: true,
1265 animations: true,
1266 three_d: false,
1267 custom_themes: true,
1268 real_time_updates: false,
1269 max_data_points: Some(100000),
1270 }
1271 }
1272}
1273
#[cfg(test)]
mod tests {
    use super::*;
    use crate::visualization_backend::OutputFormat;

    /// Three features with descending importance and per-feature std values.
    fn create_test_feature_importance_data() -> FeatureImportanceData {
        FeatureImportanceData {
            feature_names: (0..3).map(|idx| format!("Feature {}", idx)).collect(),
            importance_values: vec![0.5, 0.3, 0.2],
            std_values: Some(vec![0.1, 0.05, 0.03]),
            plot_type: crate::visualization_backend::FeatureImportanceType::Bar,
            title: String::from("Feature Importance"),
            x_label: String::from("Features"),
            y_label: String::from("Importance Score"),
        }
    }

    /// 800x600 HTML output; all other options use their defaults.
    fn create_test_config() -> BackendConfig {
        BackendConfig {
            format: OutputFormat::Html,
            width: 800,
            height: 600,
            ..Default::default()
        }
    }

    #[test]
    fn test_plotly_backend_creation() {
        let plotly = PlotlyBackend::new(PlotlyConfig::default());
        assert_eq!(plotly.name(), "plotly");
        assert!(plotly.supports_interactivity());
    }

    #[test]
    fn test_plotly_feature_importance_rendering() {
        let (data, config) = (create_test_feature_importance_data(), create_test_config());
        let rendered = PlotlyBackend::new(PlotlyConfig::default())
            .render_feature_importance(&data, &config)
            .expect("Plotly feature-importance rendering should succeed");

        assert_eq!(rendered.format, OutputFormat::Html);
        assert!(rendered.content.contains("Plotly.newPlot"));
        assert_eq!(rendered.metadata.backend, "plotly");
    }

    #[test]
    fn test_d3_backend_creation() {
        let d3 = D3Backend::new(D3Config::default());
        assert_eq!(d3.name(), "d3");
        assert!(d3.supports_interactivity());
    }

    #[test]
    fn test_d3_feature_importance_rendering() {
        let (data, config) = (create_test_feature_importance_data(), create_test_config());
        let rendered = D3Backend::new(D3Config::default())
            .render_feature_importance(&data, &config)
            .expect("D3 feature-importance rendering should succeed");

        assert_eq!(rendered.format, OutputFormat::Html);
        assert!(rendered.content.contains("d3.select"));
        assert_eq!(rendered.metadata.backend, "d3");
    }

    #[test]
    fn test_vega_lite_backend_creation() {
        let vega = VegaLiteBackend::new(VegaLiteConfig::default());
        assert_eq!(vega.name(), "vega-lite");
        assert!(vega.supports_interactivity());
    }

    #[test]
    fn test_vega_lite_feature_importance_rendering() {
        let (data, config) = (create_test_feature_importance_data(), create_test_config());
        let rendered = VegaLiteBackend::new(VegaLiteConfig::default())
            .render_feature_importance(&data, &config)
            .expect("Vega-Lite feature-importance rendering should succeed");

        assert_eq!(rendered.format, OutputFormat::Html);
        assert!(rendered.content.contains("vegaEmbed"));
        assert_eq!(rendered.metadata.backend, "vega-lite");
    }

    #[test]
    fn test_backend_capabilities() {
        let all_caps = [
            PlotlyBackend::new(PlotlyConfig::default()).capabilities(),
            D3Backend::new(D3Config::default()).capabilities(),
            VegaLiteBackend::new(VegaLiteConfig::default()).capabilities(),
        ];

        for caps in &all_caps {
            assert!(caps.interactive);
            assert!(caps.animations);
        }
    }

    #[test]
    fn test_supported_formats() {
        let formats = PlotlyBackend::new(PlotlyConfig::default()).supported_formats();

        for expected in [OutputFormat::Html, OutputFormat::Json] {
            assert!(formats.contains(&expected));
        }
    }

    #[test]
    fn test_json_output_format() {
        let data = create_test_feature_importance_data();
        let config = BackendConfig {
            format: OutputFormat::Json,
            ..create_test_config()
        };

        let rendered = PlotlyBackend::new(PlotlyConfig::default())
            .render_feature_importance(&data, &config)
            .expect("Plotly JSON rendering should succeed");
        assert_eq!(rendered.format, OutputFormat::Json);

        let parsed: serde_json::Value =
            serde_json::from_str(&rendered.content).expect("content should be valid JSON");
        assert!(parsed.get("type").is_some());
        assert!(parsed.get("javascript").is_some());
    }
}