Skip to main content

trustformers_debug/
visualization_plugins.rs

1//! Custom Visualization Plugin System
2//!
3//! This module provides an extensible plugin system for custom visualizations,
4//! allowing users to create and register their own visualization tools.
5
6use anyhow::Result;
7use parking_lot::RwLock;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::path::PathBuf;
11use std::sync::Arc;
12
/// Input payload handed to a [`VisualizationPlugin`].
///
/// Each plugin decides which variants it accepts (see
/// [`VisualizationPlugin::validate`]); the manager passes the value through
/// unchanged.
#[derive(Debug, Clone)]
pub enum VisualizationData {
    /// 1D array data
    Array1D(Vec<f64>),
    /// 2D array data (a vector of rows; nothing in this type enforces equal row lengths)
    Array2D(Vec<Vec<f64>>),
    /// Tensor data with shape information.
    ///
    /// NOTE(review): the element order of `data` relative to `shape` is not
    /// established by this type — confirm with the producing code.
    Tensor { data: Vec<f64>, shape: Vec<usize> },
    /// Key-value pairs (for metadata, metrics, etc.)
    KeyValue(HashMap<String, String>),
    /// Time series data.
    ///
    /// NOTE(review): timestamp units and how `labels` relate to `values` are
    /// not defined here — confirm with the consuming plugin.
    TimeSeries {
        timestamps: Vec<f64>,
        values: Vec<f64>,
        labels: Vec<String>,
    },
    /// Custom JSON data
    Json(serde_json::Value),
}
33
/// Output format requested from a plugin via [`PluginConfig::output_format`].
///
/// A plugin advertises the formats it can produce in
/// [`PluginMetadata::supported_outputs`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum OutputFormat {
    /// PNG image (not supported — requires binary encoder; use Svg instead)
    Png,
    /// SVG vector graphics
    Svg,
    /// HTML interactive visualization (SVG embedded in HTML)
    Html,
    /// Plain text/ASCII
    Text,
    /// JSON data
    Json,
    /// CSV data
    Csv,
}
50
/// Plugin capabilities and metadata, as reported by
/// [`VisualizationPlugin::metadata`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PluginMetadata {
    /// Plugin name. The plugin manager uses this string as the registry key.
    pub name: String,
    /// Plugin version
    pub version: String,
    /// Plugin description
    pub description: String,
    /// Author
    pub author: String,
    /// Supported input data types, given as free-form strings
    /// (the built-in plugins use `VisualizationData` variant names such as `"Array1D"`).
    pub supported_inputs: Vec<String>,
    /// Supported output formats
    pub supported_outputs: Vec<OutputFormat>,
    /// Tags/categories
    pub tags: Vec<String>,
}
69
/// Configuration for visualization plugins.
///
/// Passed by value to [`VisualizationPlugin::execute`]. See the `Default`
/// implementation for the default values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PluginConfig {
    /// Output format
    pub output_format: OutputFormat,
    /// Output path (file or directory). `None` means the output is only
    /// returned in memory.
    pub output_path: Option<String>,
    /// Width in pixels (for image outputs); the SVG renderers use it as the
    /// document width.
    pub width: usize,
    /// Height in pixels (for image outputs); the SVG renderers use it as the
    /// document height.
    pub height: usize,
    /// Color scheme
    pub color_scheme: String,
    /// Additional custom parameters. The built-in renderers read string keys
    /// such as `"title"`, `"x_label"` and `"y_label"`; the histogram plugin
    /// also reads an integer `"bins"`.
    pub custom_params: HashMap<String, serde_json::Value>,
}
86
87impl Default for PluginConfig {
88    fn default() -> Self {
89        Self {
90            output_format: OutputFormat::Svg,
91            output_path: None,
92            width: 800,
93            height: 600,
94            color_scheme: "viridis".to_string(),
95            custom_params: HashMap::new(),
96        }
97    }
98}
99
/// Result of plugin execution, returned from [`VisualizationPlugin::execute`].
#[derive(Debug)]
pub struct PluginResult {
    /// Success status. The plugin manager copies this flag into its
    /// execution history.
    pub success: bool,
    /// Output file path (if file was created)
    pub output_path: Option<String>,
    /// Raw output data (if applicable), as bytes of the rendered document
    pub output_data: Option<Vec<u8>>,
    /// Metadata about the visualization (free-form string key/value pairs)
    pub metadata: HashMap<String, String>,
    /// Error message (if failed)
    pub error: Option<String>,
}
114
/// Trait for visualization plugins
///
/// Implement this trait to create custom visualization plugins.
///
/// Implementors must be `Send + Sync`; the plugin manager stores them as
/// boxed trait objects behind a shared lock.
pub trait VisualizationPlugin: Send + Sync {
    /// Get plugin metadata
    ///
    /// The returned `name` is what the plugin manager uses as the registry key.
    fn metadata(&self) -> PluginMetadata;

    /// Execute the visualization
    ///
    /// # Arguments
    /// * `data` - Input data to visualize
    /// * `config` - Configuration for the visualization
    ///
    /// # Returns
    /// Result containing the plugin output
    ///
    /// # Errors
    /// Implementations return an error when the visualization cannot be
    /// produced (for example unsupported data or output format).
    fn execute(&self, data: VisualizationData, config: PluginConfig) -> Result<PluginResult>;

    /// Validate input data
    ///
    /// The plugin manager calls this before `execute` and refuses to run the
    /// plugin when it returns `false`.
    ///
    /// # Arguments
    /// * `data` - Data to validate
    ///
    /// # Returns
    /// True if data is valid for this plugin
    fn validate(&self, data: &VisualizationData) -> bool {
        // Default implementation: accept all data
        // (the binding only silences the unused-parameter warning)
        let _ = data;
        true
    }

    /// Get configuration schema (for UI generation)
    ///
    /// The default is an empty object schema with no properties.
    fn config_schema(&self) -> serde_json::Value {
        serde_json::json!({
            "type": "object",
            "properties": {}
        })
    }
}
153
/// Plugin manager for registering and executing visualization plugins.
///
/// Both fields sit behind `Arc<RwLock<..>>`, so all methods take `&self`.
pub struct PluginManager {
    /// Registered plugins, keyed by the name from each plugin's metadata
    plugins: Arc<RwLock<HashMap<String, Box<dyn VisualizationPlugin>>>>,
    /// Plugin execution history (append-only; never trimmed in this file)
    history: Arc<RwLock<Vec<PluginExecution>>>,
}
161
/// Record of plugin execution, stored in the manager's history.
#[derive(Debug, Clone)]
struct PluginExecution {
    /// Name the plugin was invoked under
    plugin_name: String,
    /// Wall-clock time at which the record was created
    timestamp: std::time::SystemTime,
    /// Success flag taken from the plugin's result
    success: bool,
    /// Elapsed execution time in milliseconds
    duration_ms: u128,
}
170
impl Default for PluginManager {
    /// Equivalent to [`PluginManager::new`], so the built-in plugins are
    /// registered on the default instance too.
    fn default() -> Self {
        Self::new()
    }
}
176
177impl PluginManager {
178    /// Create a new plugin manager
179    pub fn new() -> Self {
180        let manager = Self {
181            plugins: Arc::new(RwLock::new(HashMap::new())),
182            history: Arc::new(RwLock::new(Vec::new())),
183        };
184
185        // Register built-in plugins
186        manager.register_builtin_plugins();
187
188        manager
189    }
190
191    /// Register built-in plugins
192    fn register_builtin_plugins(&self) {
193        // Register histogram plugin
194        self.register_plugin(Box::new(HistogramPlugin)).ok();
195
196        // Register heatmap plugin
197        self.register_plugin(Box::new(HeatmapPlugin)).ok();
198
199        // Register line plot plugin
200        self.register_plugin(Box::new(LinePlotPlugin)).ok();
201
202        // Register scatter plot plugin
203        self.register_plugin(Box::new(ScatterPlotPlugin)).ok();
204    }
205
206    /// Register a new plugin
207    ///
208    /// # Arguments
209    /// * `plugin` - Plugin to register
210    pub fn register_plugin(&self, plugin: Box<dyn VisualizationPlugin>) -> Result<()> {
211        let name = plugin.metadata().name.clone();
212
213        self.plugins.write().insert(name.clone(), plugin);
214
215        tracing::info!(plugin_name = %name, "Registered visualization plugin");
216
217        Ok(())
218    }
219
220    /// Unregister a plugin
221    ///
222    /// # Arguments
223    /// * `name` - Name of plugin to unregister
224    pub fn unregister_plugin(&self, name: &str) -> Result<()> {
225        self.plugins.write().remove(name);
226
227        tracing::info!(plugin_name = %name, "Unregistered visualization plugin");
228
229        Ok(())
230    }
231
232    /// Get list of registered plugins
233    pub fn list_plugins(&self) -> Vec<PluginMetadata> {
234        self.plugins.read().values().map(|p| p.metadata()).collect()
235    }
236
237    /// Execute a plugin
238    ///
239    /// # Arguments
240    /// * `plugin_name` - Name of plugin to execute
241    /// * `data` - Input data
242    /// * `config` - Configuration
243    pub fn execute(
244        &self,
245        plugin_name: &str,
246        data: VisualizationData,
247        config: PluginConfig,
248    ) -> Result<PluginResult> {
249        let start_time = std::time::Instant::now();
250
251        let result = {
252            let plugins = self.plugins.read();
253            let plugin = plugins
254                .get(plugin_name)
255                .ok_or_else(|| anyhow::anyhow!("Plugin not found: {}", plugin_name))?;
256
257            // Validate data
258            if !plugin.validate(&data) {
259                anyhow::bail!("Invalid data for plugin: {}", plugin_name);
260            }
261
262            plugin.execute(data, config)?
263        };
264
265        let duration = start_time.elapsed().as_millis();
266
267        // Record execution
268        self.history.write().push(PluginExecution {
269            plugin_name: plugin_name.to_string(),
270            timestamp: std::time::SystemTime::now(),
271            success: result.success,
272            duration_ms: duration,
273        });
274
275        Ok(result)
276    }
277
278    /// Get plugin by name
279    pub fn get_plugin(&self, name: &str) -> Option<PluginMetadata> {
280        self.plugins.read().get(name).map(|p| p.metadata())
281    }
282
283    /// Get execution history
284    pub fn get_history(&self) -> Vec<String> {
285        self.history
286            .read()
287            .iter()
288            .map(|e| {
289                format!(
290                    "{}: {} ({}ms) - {}",
291                    e.timestamp.duration_since(std::time::UNIX_EPOCH).unwrap_or_default().as_secs(),
292                    e.plugin_name,
293                    e.duration_ms,
294                    if e.success { "success" } else { "failed" }
295                )
296            })
297            .collect()
298    }
299}
300
301// ============================================================================
302// SVG rendering — pure handwritten, no external drawing dep
303// ============================================================================
304
305mod svg_render {
306    //! Pure-Rust handwritten SVG generators.
307    //!
308    //! Each function returns a complete, self-contained SVG document.
309    //! Layout constants and a basic coordinate mapper are shared across helpers.
310
311    use super::PluginConfig;
312
    // Layout constants (pixels inside the SVG viewBox)
    // The four margins are the space between the canvas edge and the plot area.
    const MARGIN_TOP: f64 = 40.0;
    const MARGIN_BOTTOM: f64 = 50.0;
    const MARGIN_LEFT: f64 = 60.0;
    const MARGIN_RIGHT: f64 = 20.0;
    // Length of each axis tick mark, drawn outward from the axis line.
    const AXIS_TICK_LEN: f64 = 5.0;

    // SVG attribute strings (literal, not format arguments)
    // FONT_ATTR is a complete `name="value"` attribute spliced into <text> elements.
    const FONT_ATTR: &str = r#"font-family="sans-serif""#;
    // Stroke/fill colour for axis lines, ticks and axis text.
    const AXIS_COLOR: &str = "#555";
    // Colour for data marks: histogram bars, the line-plot stroke and scatter points.
    const BAR_COLOR: &str = "#4878CF";
324
325    /// Clamp a value to [lo, hi]
326    fn clamp_f64(v: f64, lo: f64, hi: f64) -> f64 {
327        if v < lo {
328            lo
329        } else if v > hi {
330            hi
331        } else {
332            v
333        }
334    }
335
336    /// Map a data value to a pixel coordinate within the plot area.
337    fn map_x(v: f64, data_min: f64, data_max: f64, px_left: f64, px_right: f64) -> f64 {
338        let range = data_max - data_min;
339        if range == 0.0 {
340            return (px_left + px_right) / 2.0;
341        }
342        let t = (v - data_min) / range;
343        clamp_f64(px_left + t * (px_right - px_left), px_left, px_right)
344    }
345
346    /// Map a data value to a pixel y-coordinate (data grows up, SVG y grows down).
347    fn map_y(v: f64, data_min: f64, data_max: f64, px_top: f64, px_bottom: f64) -> f64 {
348        let range = data_max - data_min;
349        if range == 0.0 {
350            return (px_top + px_bottom) / 2.0;
351        }
352        let t = (v - data_min) / range;
353        clamp_f64(px_bottom - t * (px_bottom - px_top), px_top, px_bottom)
354    }
355
356    /// SVG header with explicit width/height and a white background.
357    fn svg_open(width: usize, height: usize) -> String {
358        format!(
359            "<svg xmlns=\"http://www.w3.org/2000/svg\" width=\"{w}\" height=\"{h}\" viewBox=\"0 0 {w} {h}\">\n\
360             <rect width=\"{w}\" height=\"{h}\" fill=\"white\"/>\n",
361            w = width,
362            h = height,
363        )
364    }
365
    /// Closing tag that terminates a document started with `svg_open`.
    fn svg_close() -> &'static str {
        "</svg>"
    }
369
370    /// Render a title centred at the top of the SVG.
371    fn svg_title(text: &str, width: usize) -> String {
372        let cx = width / 2;
373        let escaped = escape_xml(text);
374        format!(
375            "<text x=\"{cx}\" y=\"24\" text-anchor=\"middle\" {font} font-size=\"16\" fill=\"#333\">{text}</text>\n",
376            cx = cx,
377            font = FONT_ATTR,
378            text = escaped,
379        )
380    }
381
382    /// Render axis lines, tick marks and numeric tick labels.
383    #[allow(clippy::too_many_arguments)]
384    fn svg_axes(
385        px_left: f64,
386        px_right: f64,
387        px_top: f64,
388        px_bottom: f64,
389        x_min: f64,
390        x_max: f64,
391        y_min: f64,
392        y_max: f64,
393        x_label: &str,
394        y_label: &str,
395        _width: usize,
396        height: usize,
397        n_ticks: usize,
398    ) -> String {
399        let mut out = String::new();
400        let c = AXIS_COLOR;
401
402        // Vertical axis line (left)
403        out.push_str(&format!(
404            "<line x1=\"{x1:.2}\" y1=\"{y1:.2}\" x2=\"{x2:.2}\" y2=\"{y2:.2}\" stroke=\"{c}\" stroke-width=\"1\"/>\n",
405            x1 = px_left, y1 = px_top, x2 = px_left, y2 = px_bottom, c = c,
406        ));
407        // Horizontal axis line (bottom)
408        out.push_str(&format!(
409            "<line x1=\"{x1:.2}\" y1=\"{y1:.2}\" x2=\"{x2:.2}\" y2=\"{y2:.2}\" stroke=\"{c}\" stroke-width=\"1\"/>\n",
410            x1 = px_left, y1 = px_bottom, x2 = px_right, y2 = px_bottom, c = c,
411        ));
412
413        // X ticks
414        let x_range = x_max - x_min;
415        for i in 0..=n_ticks {
416            let frac = i as f64 / n_ticks as f64;
417            let val = x_min + frac * x_range;
418            let px = px_left + frac * (px_right - px_left);
419            let ty = px_bottom + AXIS_TICK_LEN + 12.0;
420            let y2 = px_bottom + AXIS_TICK_LEN;
421            out.push_str(&format!(
422                "<line x1=\"{px:.2}\" y1=\"{y1:.2}\" x2=\"{px:.2}\" y2=\"{y2:.2}\" stroke=\"{c}\" stroke-width=\"1\"/>\n\
423                 <text x=\"{px:.2}\" y=\"{ty:.2}\" text-anchor=\"middle\" {font} font-size=\"10\" fill=\"{c}\">{val:.2}</text>\n",
424                px = px, y1 = px_bottom, y2 = y2, ty = ty, c = c, font = FONT_ATTR, val = val,
425            ));
426        }
427
428        // Y ticks
429        let y_range = y_max - y_min;
430        for i in 0..=n_ticks {
431            let frac = i as f64 / n_ticks as f64;
432            let val = y_min + frac * y_range;
433            let py = px_bottom - frac * (px_bottom - px_top);
434            let tick_x1 = px_left - AXIS_TICK_LEN;
435            let tx = tick_x1 - 2.0;
436            let py_t = py + 4.0;
437            out.push_str(&format!(
438                "<line x1=\"{tx1:.2}\" y1=\"{py:.2}\" x2=\"{x2:.2}\" y2=\"{py:.2}\" stroke=\"{c}\" stroke-width=\"1\"/>\n\
439                 <text x=\"{tx:.2}\" y=\"{pyt:.2}\" text-anchor=\"end\" {font} font-size=\"10\" fill=\"{c}\">{val:.2}</text>\n",
440                tx1 = tick_x1, py = py, x2 = px_left, tx = tx, pyt = py_t,
441                c = c, font = FONT_ATTR, val = val,
442            ));
443        }
444
445        // Axis labels
446        if !x_label.is_empty() {
447            let lx = (px_left + px_right) / 2.0;
448            let ly = height as f64 - 4.0;
449            let label = escape_xml(x_label);
450            out.push_str(&format!(
451                "<text x=\"{lx:.2}\" y=\"{ly:.2}\" text-anchor=\"middle\" {font} font-size=\"12\" fill=\"{c}\">{label}</text>\n",
452                lx = lx, ly = ly, font = FONT_ATTR, c = c, label = label,
453            ));
454        }
455        if !y_label.is_empty() {
456            let ry_x = -((px_top + px_bottom) / 2.0);
457            let label = escape_xml(y_label);
458            out.push_str(&format!(
459                "<text transform=\"rotate(-90)\" x=\"{rx:.2}\" y=\"14\" text-anchor=\"middle\" {font} font-size=\"12\" fill=\"{c}\">{label}</text>\n",
460                rx = ry_x, font = FONT_ATTR, c = c, label = label,
461            ));
462        }
463
464        out
465    }
466
467    /// Escape XML special characters in text content.
468    fn escape_xml(s: &str) -> String {
469        s.replace('&', "&amp;")
470            .replace('<', "&lt;")
471            .replace('>', "&gt;")
472            .replace('"', "&quot;")
473    }
474
475    // -------------------------------------------------------------------------
476    // Public renderers
477    // -------------------------------------------------------------------------
478
479    /// Render a histogram. `bins` is a slice of `(bin_left, bin_right, count)`.
480    pub fn histogram(bins: &[(f64, f64, usize)], config: &PluginConfig) -> String {
481        let w = config.width;
482        let h = config.height;
483        let px_left = MARGIN_LEFT;
484        let px_right = w as f64 - MARGIN_RIGHT;
485        let px_top = MARGIN_TOP;
486        let px_bottom = h as f64 - MARGIN_BOTTOM;
487
488        let title = config
489            .custom_params
490            .get("title")
491            .and_then(|v| v.as_str())
492            .unwrap_or("Histogram");
493        let x_label =
494            config.custom_params.get("x_label").and_then(|v| v.as_str()).unwrap_or("Value");
495        let y_label =
496            config.custom_params.get("y_label").and_then(|v| v.as_str()).unwrap_or("Count");
497
498        let max_count = bins.iter().map(|b| b.2).max().unwrap_or(1).max(1);
499        let x_min = bins.first().map(|b| b.0).unwrap_or(0.0);
500        let x_max = bins.last().map(|b| b.1).unwrap_or(1.0);
501
502        let mut out = svg_open(w, h);
503        out.push_str(&svg_title(title, w));
504        out.push_str(&svg_axes(
505            px_left,
506            px_right,
507            px_top,
508            px_bottom,
509            x_min,
510            x_max,
511            0.0,
512            max_count as f64,
513            x_label,
514            y_label,
515            w,
516            h,
517            5,
518        ));
519
520        // Draw bars
521        let fill = BAR_COLOR;
522        for (bin_left, bin_right, count) in bins {
523            let bx1 = map_x(*bin_left, x_min, x_max, px_left, px_right);
524            let bx2 = map_x(*bin_right, x_min, x_max, px_left, px_right);
525            let by_top = map_y(*count as f64, 0.0, max_count as f64, px_top, px_bottom);
526            let bar_h = px_bottom - by_top;
527            let bar_w = (bx2 - bx1).max(1.0);
528            out.push_str(&format!(
529                "<rect x=\"{x:.2}\" y=\"{y:.2}\" width=\"{bw:.2}\" height=\"{bh:.2}\" fill=\"{fill}\" stroke=\"white\" stroke-width=\"1\"/>\n",
530                x = bx1, y = by_top, bw = bar_w, bh = bar_h, fill = fill,
531            ));
532        }
533
534        out.push_str(svg_close());
535        out
536    }
537
538    /// Render a heatmap. `values` is row-major with dimensions `rows × cols`.
539    pub fn heatmap(rows: usize, cols: usize, values: &[f64], config: &PluginConfig) -> String {
540        let w = config.width;
541        let h = config.height;
542        let px_left = MARGIN_LEFT;
543        let px_right = w as f64 - MARGIN_RIGHT;
544        let px_top = MARGIN_TOP;
545        let px_bottom = h as f64 - MARGIN_BOTTOM;
546
547        let title = config.custom_params.get("title").and_then(|v| v.as_str()).unwrap_or("Heatmap");
548
549        let cell_w = if cols > 0 { (px_right - px_left) / cols as f64 } else { 1.0 };
550        let cell_h = if rows > 0 { (px_bottom - px_top) / rows as f64 } else { 1.0 };
551
552        let (v_min, v_max) =
553            values.iter().copied().fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), v| {
554                (lo.min(v), hi.max(v))
555            });
556        let v_range = (v_max - v_min).max(f64::EPSILON);
557
558        let mut out = svg_open(w, h);
559        out.push_str(&svg_title(title, w));
560
561        // Draw cells
562        for row_idx in 0..rows {
563            for col_idx in 0..cols {
564                let idx = row_idx * cols + col_idx;
565                let val = values.get(idx).copied().unwrap_or(0.0);
566                let t = ((val - v_min) / v_range).clamp(0.0, 1.0);
567                // Viridis-like gradient: dark purple to yellow
568                let red = (255.0 * (t * t)) as u8;
569                let green = (255.0 * t * (1.0 - t * 0.5)) as u8;
570                let blue = (255.0 * (1.0 - t)) as u8;
571                let cx = px_left + col_idx as f64 * cell_w;
572                let cy = px_top + row_idx as f64 * cell_h;
573                out.push_str(&format!(
574                    "<rect x=\"{x:.2}\" y=\"{y:.2}\" width=\"{cw:.2}\" height=\"{ch:.2}\" fill=\"rgb({r},{g},{b})\"/>\n",
575                    x = cx, y = cy, cw = cell_w, ch = cell_h,
576                    r = red, g = green, b = blue,
577                ));
578            }
579        }
580
581        out.push_str(svg_close());
582        out
583    }
584
585    /// Render a line plot. `points` is a slice of `(x, y)` pairs, ordered by x.
586    pub fn line_plot(points: &[(f64, f64)], config: &PluginConfig) -> String {
587        let w = config.width;
588        let h = config.height;
589        let px_left = MARGIN_LEFT;
590        let px_right = w as f64 - MARGIN_RIGHT;
591        let px_top = MARGIN_TOP;
592        let px_bottom = h as f64 - MARGIN_BOTTOM;
593
594        let title = config
595            .custom_params
596            .get("title")
597            .and_then(|v| v.as_str())
598            .unwrap_or("Line Plot");
599        let x_label = config.custom_params.get("x_label").and_then(|v| v.as_str()).unwrap_or("X");
600        let y_label = config.custom_params.get("y_label").and_then(|v| v.as_str()).unwrap_or("Y");
601
602        let (x_min, x_max, y_min, y_max) = data_bounds(points);
603
604        let mut out = svg_open(w, h);
605        out.push_str(&svg_title(title, w));
606        out.push_str(&svg_axes(
607            px_left, px_right, px_top, px_bottom, x_min, x_max, y_min, y_max, x_label, y_label, w,
608            h, 5,
609        ));
610
611        if !points.is_empty() {
612            // Build a polyline
613            let pts: Vec<String> = points
614                .iter()
615                .map(|(x, y)| {
616                    let px = map_x(*x, x_min, x_max, px_left, px_right);
617                    let py = map_y(*y, y_min, y_max, px_top, px_bottom);
618                    format!("{:.2},{:.2}", px, py)
619                })
620                .collect();
621            let stroke = BAR_COLOR;
622            out.push_str(&format!(
623                "<polyline points=\"{pts}\" fill=\"none\" stroke=\"{stroke}\" stroke-width=\"2\"/>\n",
624                pts = pts.join(" "),
625                stroke = stroke,
626            ));
627        }
628
629        out.push_str(svg_close());
630        out
631    }
632
633    /// Render a scatter plot. `points` is a slice of `(x, y)` pairs.
634    pub fn scatter(points: &[(f64, f64)], config: &PluginConfig) -> String {
635        let w = config.width;
636        let h = config.height;
637        let px_left = MARGIN_LEFT;
638        let px_right = w as f64 - MARGIN_RIGHT;
639        let px_top = MARGIN_TOP;
640        let px_bottom = h as f64 - MARGIN_BOTTOM;
641
642        let title = config
643            .custom_params
644            .get("title")
645            .and_then(|v| v.as_str())
646            .unwrap_or("Scatter Plot");
647        let x_label = config.custom_params.get("x_label").and_then(|v| v.as_str()).unwrap_or("X");
648        let y_label = config.custom_params.get("y_label").and_then(|v| v.as_str()).unwrap_or("Y");
649
650        let (x_min, x_max, y_min, y_max) = data_bounds(points);
651
652        let mut out = svg_open(w, h);
653        out.push_str(&svg_title(title, w));
654        out.push_str(&svg_axes(
655            px_left, px_right, px_top, px_bottom, x_min, x_max, y_min, y_max, x_label, y_label, w,
656            h, 5,
657        ));
658
659        let fill = BAR_COLOR;
660        for (x, y) in points {
661            let px = map_x(*x, x_min, x_max, px_left, px_right);
662            let py = map_y(*y, y_min, y_max, px_top, px_bottom);
663            out.push_str(&format!(
664                "<circle cx=\"{cx:.2}\" cy=\"{cy:.2}\" r=\"4\" fill=\"{fill}\" fill-opacity=\"0.7\"/>\n",
665                cx = px, cy = py, fill = fill,
666            ));
667        }
668
669        out.push_str(svg_close());
670        out
671    }
672
673    // -------------------------------------------------------------------------
674    // JSON / CSV renderers
675    // -------------------------------------------------------------------------
676
677    /// Render histogram data as JSON.
678    pub fn histogram_json(
679        bins: &[(f64, f64, usize)],
680        min: f64,
681        max: f64,
682        n_values: usize,
683    ) -> String {
684        let bins_arr: Vec<serde_json::Value> = bins
685            .iter()
686            .map(|(lo, hi, cnt)| {
687                serde_json::json!({
688                    "bin_left": lo,
689                    "bin_right": hi,
690                    "count": cnt,
691                })
692            })
693            .collect();
694        serde_json::to_string_pretty(&serde_json::json!({
695            "type": "histogram",
696            "n_values": n_values,
697            "min": min,
698            "max": max,
699            "bins": bins_arr,
700        }))
701        .unwrap_or_else(|_| "{}".to_string())
702    }
703
704    /// Render histogram data as CSV.
705    pub fn histogram_csv(bins: &[(f64, f64, usize)]) -> String {
706        let mut out = String::from("bin_left,bin_right,count\n");
707        for (lo, hi, cnt) in bins {
708            out.push_str(&format!("{},{},{}\n", lo, hi, cnt));
709        }
710        out
711    }
712
713    /// Render 2D point data as JSON.
714    pub fn points_json(kind: &str, points: &[(f64, f64)]) -> String {
715        let pts: Vec<serde_json::Value> =
716            points.iter().map(|(x, y)| serde_json::json!({ "x": x, "y": y })).collect();
717        serde_json::to_string_pretty(&serde_json::json!({
718            "type": kind,
719            "n_points": points.len(),
720            "points": pts,
721        }))
722        .unwrap_or_else(|_| "{}".to_string())
723    }
724
725    /// Render 2D point data as CSV.
726    pub fn points_csv(points: &[(f64, f64)]) -> String {
727        let mut out = String::from("x,y\n");
728        for (x, y) in points {
729            out.push_str(&format!("{},{}\n", x, y));
730        }
731        out
732    }
733
734    /// Wrap an SVG string in a minimal HTML document.
735    pub fn wrap_html(svg: &str) -> String {
736        format!(
737            "<!DOCTYPE html><html><head><meta charset=\"utf-8\"/></head><body>{svg}</body></html>",
738            svg = svg
739        )
740    }
741
742    // -------------------------------------------------------------------------
743    // Internal helpers
744    // -------------------------------------------------------------------------
745
746    /// Compute (x_min, x_max, y_min, y_max) for a slice of (x, y) points.
747    fn data_bounds(points: &[(f64, f64)]) -> (f64, f64, f64, f64) {
748        let mut x_min = f64::INFINITY;
749        let mut x_max = f64::NEG_INFINITY;
750        let mut y_min = f64::INFINITY;
751        let mut y_max = f64::NEG_INFINITY;
752        for (x, y) in points {
753            if *x < x_min {
754                x_min = *x;
755            }
756            if *x > x_max {
757                x_max = *x;
758            }
759            if *y < y_min {
760                y_min = *y;
761            }
762            if *y > y_max {
763                y_max = *y;
764            }
765        }
766        // Pad a little so points at edge are visible
767        let x_pad = if (x_max - x_min).abs() < f64::EPSILON { 1.0 } else { (x_max - x_min) * 0.05 };
768        let y_pad = if (y_max - y_min).abs() < f64::EPSILON { 1.0 } else { (y_max - y_min) * 0.05 };
769        (x_min - x_pad, x_max + x_pad, y_min - y_pad, y_max + y_pad)
770    }
771}
772
773// ============================================================================
774// Helper: write bytes to output_path when configured
775// ============================================================================
776
777fn write_output_path(path: &str, bytes: &[u8]) -> Result<()> {
778    let p = PathBuf::from(path);
779    std::fs::write(&p, bytes)
780        .map_err(|e| anyhow::anyhow!("Failed to write output to {}: {}", p.display(), e))
781}
782
783// ============================================================================
784// Built-in Plugins
785// ============================================================================
786
787/// Histogram visualization plugin
788struct HistogramPlugin;
789
790impl VisualizationPlugin for HistogramPlugin {
791    fn metadata(&self) -> PluginMetadata {
792        PluginMetadata {
793            name: "histogram".to_string(),
794            version: "1.0.0".to_string(),
795            description: "Generates histogram visualizations".to_string(),
796            author: "TrustformeRS".to_string(),
797            supported_inputs: vec!["Array1D".to_string(), "Tensor".to_string()],
798            supported_outputs: vec![
799                OutputFormat::Svg,
800                OutputFormat::Html,
801                OutputFormat::Text,
802                OutputFormat::Json,
803                OutputFormat::Csv,
804            ],
805            tags: vec!["distribution".to_string(), "statistics".to_string()],
806        }
807    }
808
809    fn config_schema(&self) -> serde_json::Value {
810        serde_json::json!({
811            "width": 800,
812            "height": 600,
813            "title": "Histogram",
814            "x_label": "Value",
815            "y_label": "Count",
816            "bins": 20
817        })
818    }
819
820    fn execute(&self, data: VisualizationData, config: PluginConfig) -> Result<PluginResult> {
821        let values = match data {
822            VisualizationData::Array1D(v) => v,
823            VisualizationData::Tensor { data, .. } => data,
824            _ => anyhow::bail!("Unsupported data type for histogram"),
825        };
826
827        if values.is_empty() {
828            anyhow::bail!("Histogram requires non-empty input data");
829        }
830
831        // Calculate histogram bins
832        let n_bins =
833            config.custom_params.get("bins").and_then(|v| v.as_u64()).unwrap_or(20) as usize;
834        let n_bins = n_bins.max(1);
835
836        let min = values.iter().copied().fold(f64::INFINITY, f64::min);
837        let max = values.iter().copied().fold(f64::NEG_INFINITY, f64::max);
838        let bin_width =
839            if (max - min).abs() < f64::EPSILON { 1.0 } else { (max - min) / n_bins as f64 };
840
841        let mut counts = vec![0usize; n_bins];
842        for &value in &values {
843            let bin_idx = ((value - min) / bin_width).floor() as usize;
844            let bin_idx = bin_idx.min(n_bins - 1);
845            counts[bin_idx] += 1;
846        }
847
848        let bins: Vec<(f64, f64, usize)> = counts
849            .iter()
850            .enumerate()
851            .map(|(i, &cnt)| {
852                let lo = min + i as f64 * bin_width;
853                let hi = lo + bin_width;
854                (lo, hi, cnt)
855            })
856            .collect();
857
858        let bytes = match config.output_format {
859            OutputFormat::Png => {
860                anyhow::bail!(
861                    "PNG output is not supported (requires binary encoder); use Svg instead"
862                )
863            },
864            OutputFormat::Svg => {
865                let svg = svg_render::histogram(&bins, &config);
866                svg.into_bytes()
867            },
868            OutputFormat::Html => {
869                let svg = svg_render::histogram(&bins, &config);
870                svg_render::wrap_html(&svg).into_bytes()
871            },
872            OutputFormat::Text => {
873                let output_text = format!(
874                    "Histogram (bins={}):\nMin={:.4}, Max={:.4}\nBin counts: {:?}",
875                    n_bins, min, max, counts
876                );
877                output_text.into_bytes()
878            },
879            OutputFormat::Json => {
880                svg_render::histogram_json(&bins, min, max, values.len()).into_bytes()
881            },
882            OutputFormat::Csv => svg_render::histogram_csv(&bins).into_bytes(),
883        };
884
885        let out_path_str = if let Some(ref path) = config.output_path {
886            write_output_path(path, &bytes)?;
887            Some(path.clone())
888        } else {
889            None
890        };
891
892        Ok(PluginResult {
893            success: true,
894            output_path: out_path_str,
895            output_data: Some(bytes),
896            metadata: {
897                let mut m = HashMap::new();
898                m.insert("bins".to_string(), n_bins.to_string());
899                m.insert("min".to_string(), min.to_string());
900                m.insert("max".to_string(), max.to_string());
901                m.insert("n_values".to_string(), values.len().to_string());
902                m
903            },
904            error: None,
905        })
906    }
907
908    fn validate(&self, data: &VisualizationData) -> bool {
909        matches!(
910            data,
911            VisualizationData::Array1D(_) | VisualizationData::Tensor { .. }
912        )
913    }
914}
915
916/// Heatmap visualization plugin
917struct HeatmapPlugin;
918
919impl VisualizationPlugin for HeatmapPlugin {
920    fn metadata(&self) -> PluginMetadata {
921        PluginMetadata {
922            name: "heatmap".to_string(),
923            version: "1.0.0".to_string(),
924            description: "Generates heatmap visualizations for 2D data".to_string(),
925            author: "TrustformeRS".to_string(),
926            supported_inputs: vec!["Array2D".to_string(), "Tensor".to_string()],
927            supported_outputs: vec![
928                OutputFormat::Svg,
929                OutputFormat::Html,
930                OutputFormat::Json,
931                OutputFormat::Csv,
932                OutputFormat::Text,
933            ],
934            tags: vec!["matrix".to_string(), "2d".to_string()],
935        }
936    }
937
938    fn config_schema(&self) -> serde_json::Value {
939        serde_json::json!({
940            "width": 800,
941            "height": 600,
942            "title": "Heatmap",
943            "x_label": "",
944            "y_label": ""
945        })
946    }
947
948    fn execute(&self, data: VisualizationData, config: PluginConfig) -> Result<PluginResult> {
949        let (rows, cols, flat_values) = match &data {
950            VisualizationData::Array2D(v) => {
951                let r = v.len();
952                let c = v.first().map(|row| row.len()).unwrap_or(0);
953                let flat: Vec<f64> = v.iter().flat_map(|row| row.iter().copied()).collect();
954                (r, c, flat)
955            },
956            VisualizationData::Tensor { shape, data } if shape.len() == 2 => {
957                (shape[0], shape[1], data.clone())
958            },
959            _ => anyhow::bail!("Heatmap requires 2D data"),
960        };
961
962        let bytes = match config.output_format {
963            OutputFormat::Png => {
964                anyhow::bail!(
965                    "PNG output is not supported (requires binary encoder); use Svg instead"
966                )
967            },
968            OutputFormat::Svg => {
969                let svg = svg_render::heatmap(rows, cols, &flat_values, &config);
970                svg.into_bytes()
971            },
972            OutputFormat::Html => {
973                let svg = svg_render::heatmap(rows, cols, &flat_values, &config);
974                svg_render::wrap_html(&svg).into_bytes()
975            },
976            OutputFormat::Text => format!("Heatmap {}x{}", rows, cols).into_bytes(),
977            OutputFormat::Json => {
978                let cells: Vec<serde_json::Value> = flat_values
979                    .iter()
980                    .enumerate()
981                    .map(|(i, v)| {
982                        serde_json::json!({
983                            "row": i / cols.max(1),
984                            "col": i % cols.max(1),
985                            "value": v,
986                        })
987                    })
988                    .collect();
989                serde_json::to_string_pretty(&serde_json::json!({
990                    "type": "heatmap",
991                    "rows": rows,
992                    "cols": cols,
993                    "cells": cells,
994                }))
995                .unwrap_or_else(|_| "{}".to_string())
996                .into_bytes()
997            },
998            OutputFormat::Csv => {
999                let mut out = String::from("row,col,value\n");
1000                for (i, v) in flat_values.iter().enumerate() {
1001                    let r = i / cols.max(1);
1002                    let c = i % cols.max(1);
1003                    out.push_str(&format!("{},{},{}\n", r, c, v));
1004                }
1005                out.into_bytes()
1006            },
1007        };
1008
1009        let out_path_str = if let Some(ref path) = config.output_path {
1010            write_output_path(path, &bytes)?;
1011            Some(path.clone())
1012        } else {
1013            None
1014        };
1015
1016        Ok(PluginResult {
1017            success: true,
1018            output_path: out_path_str,
1019            output_data: Some(bytes),
1020            metadata: {
1021                let mut m = HashMap::new();
1022                m.insert("rows".to_string(), rows.to_string());
1023                m.insert("cols".to_string(), cols.to_string());
1024                m
1025            },
1026            error: None,
1027        })
1028    }
1029
1030    fn validate(&self, data: &VisualizationData) -> bool {
1031        match data {
1032            VisualizationData::Array2D(_) => true,
1033            VisualizationData::Tensor { shape, .. } => shape.len() == 2,
1034            _ => false,
1035        }
1036    }
1037}
1038
1039/// Line plot visualization plugin
1040struct LinePlotPlugin;
1041
1042impl VisualizationPlugin for LinePlotPlugin {
1043    fn metadata(&self) -> PluginMetadata {
1044        PluginMetadata {
1045            name: "lineplot".to_string(),
1046            version: "1.0.0".to_string(),
1047            description: "Generates line plots for time series data".to_string(),
1048            author: "TrustformeRS".to_string(),
1049            supported_inputs: vec!["TimeSeries".to_string(), "Array1D".to_string()],
1050            supported_outputs: vec![
1051                OutputFormat::Svg,
1052                OutputFormat::Html,
1053                OutputFormat::Text,
1054                OutputFormat::Json,
1055                OutputFormat::Csv,
1056            ],
1057            tags: vec!["timeseries".to_string(), "trend".to_string()],
1058        }
1059    }
1060
1061    fn config_schema(&self) -> serde_json::Value {
1062        serde_json::json!({
1063            "width": 800,
1064            "height": 600,
1065            "title": "Line Plot",
1066            "x_label": "X",
1067            "y_label": "Y"
1068        })
1069    }
1070
1071    fn execute(&self, data: VisualizationData, config: PluginConfig) -> Result<PluginResult> {
1072        let points: Vec<(f64, f64)> = match &data {
1073            VisualizationData::TimeSeries {
1074                timestamps, values, ..
1075            } => timestamps.iter().zip(values.iter()).map(|(t, v)| (*t, *v)).collect(),
1076            VisualizationData::Array1D(v) => {
1077                v.iter().enumerate().map(|(i, val)| (i as f64, *val)).collect()
1078            },
1079            _ => anyhow::bail!("Line plot requires time series or 1D array data"),
1080        };
1081
1082        let n_points = points.len();
1083
1084        let bytes = match config.output_format {
1085            OutputFormat::Png => {
1086                anyhow::bail!(
1087                    "PNG output is not supported (requires binary encoder); use Svg instead"
1088                )
1089            },
1090            OutputFormat::Svg => {
1091                let svg = svg_render::line_plot(&points, &config);
1092                svg.into_bytes()
1093            },
1094            OutputFormat::Html => {
1095                let svg = svg_render::line_plot(&points, &config);
1096                svg_render::wrap_html(&svg).into_bytes()
1097            },
1098            OutputFormat::Text => format!("Line plot with {} points", n_points).into_bytes(),
1099            OutputFormat::Json => svg_render::points_json("lineplot", &points).into_bytes(),
1100            OutputFormat::Csv => svg_render::points_csv(&points).into_bytes(),
1101        };
1102
1103        let out_path_str = if let Some(ref path) = config.output_path {
1104            write_output_path(path, &bytes)?;
1105            Some(path.clone())
1106        } else {
1107            None
1108        };
1109
1110        Ok(PluginResult {
1111            success: true,
1112            output_path: out_path_str,
1113            output_data: Some(bytes),
1114            metadata: {
1115                let mut m = HashMap::new();
1116                m.insert("points".to_string(), n_points.to_string());
1117                m
1118            },
1119            error: None,
1120        })
1121    }
1122
1123    fn validate(&self, data: &VisualizationData) -> bool {
1124        matches!(
1125            data,
1126            VisualizationData::TimeSeries { .. } | VisualizationData::Array1D(_)
1127        )
1128    }
1129}
1130
1131/// Scatter plot visualization plugin
1132struct ScatterPlotPlugin;
1133
1134impl VisualizationPlugin for ScatterPlotPlugin {
1135    fn metadata(&self) -> PluginMetadata {
1136        PluginMetadata {
1137            name: "scatterplot".to_string(),
1138            version: "1.0.0".to_string(),
1139            description: "Generates scatter plots for 2D point data".to_string(),
1140            author: "TrustformeRS".to_string(),
1141            supported_inputs: vec!["Array2D".to_string()],
1142            supported_outputs: vec![
1143                OutputFormat::Svg,
1144                OutputFormat::Html,
1145                OutputFormat::Text,
1146                OutputFormat::Json,
1147                OutputFormat::Csv,
1148            ],
1149            tags: vec!["correlation".to_string(), "distribution".to_string()],
1150        }
1151    }
1152
1153    fn config_schema(&self) -> serde_json::Value {
1154        serde_json::json!({
1155            "width": 800,
1156            "height": 600,
1157            "title": "Scatter Plot",
1158            "x_label": "X",
1159            "y_label": "Y"
1160        })
1161    }
1162
1163    fn execute(&self, data: VisualizationData, config: PluginConfig) -> Result<PluginResult> {
1164        let points: Vec<(f64, f64)> = match &data {
1165            VisualizationData::Array2D(v) => v
1166                .iter()
1167                .filter_map(
1168                    |row| {
1169                        if row.len() >= 2 {
1170                            Some((row[0], row[1]))
1171                        } else {
1172                            None
1173                        }
1174                    },
1175                )
1176                .collect(),
1177            _ => anyhow::bail!("Scatter plot requires 2D array data (each row = [x, y])"),
1178        };
1179
1180        let n_points = points.len();
1181
1182        let bytes = match config.output_format {
1183            OutputFormat::Png => {
1184                anyhow::bail!(
1185                    "PNG output is not supported (requires binary encoder); use Svg instead"
1186                )
1187            },
1188            OutputFormat::Svg => {
1189                let svg = svg_render::scatter(&points, &config);
1190                svg.into_bytes()
1191            },
1192            OutputFormat::Html => {
1193                let svg = svg_render::scatter(&points, &config);
1194                svg_render::wrap_html(&svg).into_bytes()
1195            },
1196            OutputFormat::Text => format!("Scatter plot with {} points", n_points).into_bytes(),
1197            OutputFormat::Json => svg_render::points_json("scatterplot", &points).into_bytes(),
1198            OutputFormat::Csv => svg_render::points_csv(&points).into_bytes(),
1199        };
1200
1201        let out_path_str = if let Some(ref path) = config.output_path {
1202            write_output_path(path, &bytes)?;
1203            Some(path.clone())
1204        } else {
1205            None
1206        };
1207
1208        Ok(PluginResult {
1209            success: true,
1210            output_path: out_path_str,
1211            output_data: Some(bytes),
1212            metadata: {
1213                let mut m = HashMap::new();
1214                m.insert("points".to_string(), n_points.to_string());
1215                m
1216            },
1217            error: None,
1218        })
1219    }
1220
1221    fn validate(&self, data: &VisualizationData) -> bool {
1222        matches!(data, VisualizationData::Array2D(_))
1223    }
1224}
1225
#[cfg(test)]
mod tests {
    use super::*;

    /// Extracts the rendered payload of `result` as UTF-8 text, panicking
    /// when the payload is missing or is not valid UTF-8.
    fn rendered_text(result: PluginResult) -> String {
        let bytes = result.output_data.expect("no output data");
        String::from_utf8(bytes).expect("invalid UTF-8")
    }

    // -----------------------------------------------------------------------
    // Existing tests (updated for Svg default)
    // -----------------------------------------------------------------------

    #[test]
    fn test_plugin_manager_creation() {
        let manager = PluginManager::new();
        let plugins = manager.list_plugins();
        assert!(!plugins.is_empty());
    }

    #[test]
    fn test_histogram_plugin() {
        let manager = PluginManager::new();
        let data = VisualizationData::Array1D(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let config = PluginConfig::default(); // default is now Svg

        let result = manager.execute("histogram", data, config).expect("operation failed in test");

        assert!(result.success);
        assert!(result.output_data.is_some());
    }

    #[test]
    fn test_plugin_validation() {
        let manager = PluginManager::new();

        // Valid data for histogram
        let data = VisualizationData::Array1D(vec![1.0, 2.0, 3.0]);
        let config = PluginConfig::default();
        assert!(manager.execute("histogram", data, config.clone()).is_ok());

        // Invalid data for heatmap (needs 2D)
        let data = VisualizationData::Array1D(vec![1.0, 2.0, 3.0]);
        assert!(manager.execute("heatmap", data, config).is_err());
    }

    #[test]
    fn test_custom_plugin_registration() {
        let manager = PluginManager::new();
        let count_before = manager.list_plugins().len();

        // Register histogram again (should replace)
        manager
            .register_plugin(Box::new(HistogramPlugin))
            .expect("operation failed in test");

        let count_after = manager.list_plugins().len();
        assert_eq!(count_before, count_after);
    }

    // -----------------------------------------------------------------------
    // New SVG rendering tests
    // -----------------------------------------------------------------------

    /// Default configuration with the output format pinned to SVG.
    fn svg_config() -> PluginConfig {
        PluginConfig {
            output_format: OutputFormat::Svg,
            ..PluginConfig::default()
        }
    }

    #[test]
    fn test_histogram_svg_contains_rect() {
        let plugin = HistogramPlugin;
        let data = VisualizationData::Array1D(vec![1.0, 2.0, 2.5, 3.0, 4.0, 5.0, 5.5]);
        let result = plugin.execute(data, svg_config()).expect("histogram SVG render failed");
        assert!(result.success);
        let svg = rendered_text(result);
        // Build the preview per character: slicing by byte offset could panic
        // in the middle of a multi-byte character.
        let preview: String = svg.chars().take(300).collect();
        assert!(
            svg.contains("<rect"),
            "SVG histogram should contain <rect elements; got: {}",
            preview
        );
        assert!(svg.starts_with("<svg"), "output should start with <svg tag");
    }

    #[test]
    fn test_heatmap_svg_contains_rect_cells() {
        let plugin = HeatmapPlugin;
        let data = VisualizationData::Array2D(vec![
            vec![1.0, 2.0, 3.0],
            vec![4.0, 5.0, 6.0],
            vec![7.0, 8.0, 9.0],
        ]);
        let result = plugin.execute(data, svg_config()).expect("heatmap SVG render failed");
        assert!(result.success);
        let svg = rendered_text(result);
        assert!(
            svg.contains("<rect"),
            "SVG heatmap should contain <rect cell elements"
        );
        assert!(svg.starts_with("<svg"));
    }

    #[test]
    fn test_line_plot_svg_contains_polyline() {
        let plugin = LinePlotPlugin;
        let data = VisualizationData::TimeSeries {
            timestamps: vec![0.0, 1.0, 2.0, 3.0, 4.0],
            values: vec![0.1, 0.4, 0.9, 0.3, 0.7],
            labels: vec![],
        };
        let result = plugin.execute(data, svg_config()).expect("line plot SVG render failed");
        assert!(result.success);
        let svg = rendered_text(result);
        assert!(
            svg.contains("<polyline"),
            "SVG line plot should contain <polyline element"
        );
        assert!(svg.starts_with("<svg"));
    }

    #[test]
    fn test_scatter_svg_contains_circle() {
        let plugin = ScatterPlotPlugin;
        let data = VisualizationData::Array2D(vec![
            vec![1.0, 2.0],
            vec![3.0, 4.0],
            vec![5.0, 6.0],
            vec![7.0, 8.0],
        ]);
        let result = plugin.execute(data, svg_config()).expect("scatter plot SVG render failed");
        assert!(result.success);
        let svg = rendered_text(result);
        assert!(
            svg.contains("<circle"),
            "SVG scatter plot should contain <circle elements"
        );
        assert!(svg.starts_with("<svg"));
    }

    #[test]
    fn test_output_path_writes_file() {
        let plugin = HistogramPlugin;
        let data = VisualizationData::Array1D(vec![1.0, 2.0, 3.0, 4.0, 5.0]);

        // Include the process id so concurrent test processes sharing the
        // temp directory do not overwrite each other's file.
        let out_file = std::env::temp_dir()
            .join(format!("test_histogram_visp_{}.svg", std::process::id()));
        let out_path_str = out_file.to_str().expect("temp path is not valid UTF-8").to_string();

        // A leftover file from an earlier run must not satisfy the read below.
        std::fs::remove_file(&out_file).ok();

        let config = PluginConfig {
            output_format: OutputFormat::Svg,
            output_path: Some(out_path_str.clone()),
            ..PluginConfig::default()
        };

        let result = plugin.execute(data, config).expect("histogram with output_path failed");
        assert!(result.success);
        assert_eq!(result.output_path.as_deref(), Some(out_path_str.as_str()));

        // Read first, clean up second, assert last: a failing assertion then
        // cannot leave the file behind.
        let written = std::fs::read_to_string(&out_file);
        std::fs::remove_file(&out_file).ok();

        let written = written.expect("output file not found on disk");
        assert!(
            written.contains("<rect"),
            "Written SVG file should contain <rect"
        );
    }

    #[test]
    fn test_png_returns_error() {
        let plugin = HistogramPlugin;
        let data = VisualizationData::Array1D(vec![1.0, 2.0, 3.0]);
        let config = PluginConfig {
            output_format: OutputFormat::Png,
            ..PluginConfig::default()
        };
        let result = plugin.execute(data, config);
        assert!(result.is_err(), "PNG format should return an error, not Ok");
        let err = result.unwrap_err();
        let msg = err.to_string().to_lowercase();
        assert!(
            msg.contains("png") || msg.contains("not supported"),
            "Error message should mention PNG or not supported; got: {}",
            err
        );
    }

    #[test]
    fn test_html_wraps_svg() {
        let plugin = ScatterPlotPlugin;
        let data = VisualizationData::Array2D(vec![vec![0.0, 1.0], vec![2.0, 3.0]]);
        let config = PluginConfig {
            output_format: OutputFormat::Html,
            ..PluginConfig::default()
        };
        let result = plugin.execute(data, config).expect("HTML render failed");
        let html = rendered_text(result);
        assert!(
            html.contains("<!DOCTYPE html>"),
            "HTML output should contain DOCTYPE"
        );
        assert!(html.contains("<svg"), "HTML output should embed SVG");
    }

    #[test]
    fn test_histogram_json_output() {
        let plugin = HistogramPlugin;
        let data = VisualizationData::Array1D(vec![1.0, 2.0, 3.0, 4.0]);
        let config = PluginConfig {
            output_format: OutputFormat::Json,
            ..PluginConfig::default()
        };
        let result = plugin.execute(data, config).expect("JSON render failed");
        let json_str = rendered_text(result);
        let parsed: serde_json::Value =
            serde_json::from_str(&json_str).expect("output is not valid JSON");
        assert_eq!(parsed["type"], "histogram");
        assert!(parsed["bins"].is_array());
    }

    #[test]
    fn test_histogram_csv_output() {
        let plugin = HistogramPlugin;
        let data = VisualizationData::Array1D(vec![1.0, 2.0, 3.0, 4.0]);
        let config = PluginConfig {
            output_format: OutputFormat::Csv,
            ..PluginConfig::default()
        };
        let result = plugin.execute(data, config).expect("CSV render failed");
        let csv_str = rendered_text(result);
        assert!(
            csv_str.starts_with("bin_left,bin_right,count"),
            "CSV should start with header"
        );
    }

    #[test]
    fn test_config_schema_fields() {
        let hist_schema = HistogramPlugin.config_schema();
        assert!(hist_schema["width"].is_number());
        assert!(hist_schema["height"].is_number());

        let heat_schema = HeatmapPlugin.config_schema();
        assert!(heat_schema["width"].is_number());

        let line_schema = LinePlotPlugin.config_schema();
        assert!(line_schema["x_label"].is_string());

        let scatter_schema = ScatterPlotPlugin.config_schema();
        assert!(scatter_schema["title"].is_string());
    }
}