Skip to main content

trustformers_debug/
visualization_plugins.rs

//! Custom Visualization Plugin System
//!
//! This module provides an extensible plugin system for custom visualizations,
//! allowing users to create and register their own visualization tools.
5
6use anyhow::Result;
7use parking_lot::RwLock;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::sync::Arc;
11
12/// Visualization data that can be passed to plugins
13#[derive(Debug, Clone)]
14pub enum VisualizationData {
15    /// 1D array data
16    Array1D(Vec<f64>),
17    /// 2D array data
18    Array2D(Vec<Vec<f64>>),
19    /// Tensor data with shape information
20    Tensor { data: Vec<f64>, shape: Vec<usize> },
21    /// Key-value pairs (for metadata, metrics, etc.)
22    KeyValue(HashMap<String, String>),
23    /// Time series data
24    TimeSeries {
25        timestamps: Vec<f64>,
26        values: Vec<f64>,
27        labels: Vec<String>,
28    },
29    /// Custom JSON data
30    Json(serde_json::Value),
31}
32
33/// Output format for visualizations
34#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
35pub enum OutputFormat {
36    /// PNG image
37    Png,
38    /// SVG vector graphics
39    Svg,
40    /// HTML interactive visualization
41    Html,
42    /// Plain text/ASCII
43    Text,
44    /// JSON data
45    Json,
46    /// CSV data
47    Csv,
48}
49
50/// Plugin capabilities and metadata
51#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct PluginMetadata {
53    /// Plugin name
54    pub name: String,
55    /// Plugin version
56    pub version: String,
57    /// Plugin description
58    pub description: String,
59    /// Author
60    pub author: String,
61    /// Supported input data types
62    pub supported_inputs: Vec<String>,
63    /// Supported output formats
64    pub supported_outputs: Vec<OutputFormat>,
65    /// Tags/categories
66    pub tags: Vec<String>,
67}
68
69/// Configuration for visualization plugins
70#[derive(Debug, Clone, Serialize, Deserialize)]
71pub struct PluginConfig {
72    /// Output format
73    pub output_format: OutputFormat,
74    /// Output path (file or directory)
75    pub output_path: Option<String>,
76    /// Width in pixels (for image outputs)
77    pub width: usize,
78    /// Height in pixels (for image outputs)
79    pub height: usize,
80    /// Color scheme
81    pub color_scheme: String,
82    /// Additional custom parameters
83    pub custom_params: HashMap<String, serde_json::Value>,
84}
85
86impl Default for PluginConfig {
87    fn default() -> Self {
88        Self {
89            output_format: OutputFormat::Png,
90            output_path: None,
91            width: 800,
92            height: 600,
93            color_scheme: "viridis".to_string(),
94            custom_params: HashMap::new(),
95        }
96    }
97}
98
/// What a plugin hands back after it has run.
#[derive(Debug)]
pub struct PluginResult {
    /// Whether the plugin reports the run as successful.
    pub success: bool,
    /// Path of the file that was written, when one was created.
    pub output_path: Option<String>,
    /// Raw bytes of the output, when the plugin returns it in memory.
    pub output_data: Option<Vec<u8>>,
    /// String key/value details about the produced visualization.
    pub metadata: HashMap<String, String>,
    /// Failure description, when the run did not succeed.
    pub error: Option<String>,
}
113
114/// Trait for visualization plugins
115///
116/// Implement this trait to create custom visualization plugins.
117pub trait VisualizationPlugin: Send + Sync {
118    /// Get plugin metadata
119    fn metadata(&self) -> PluginMetadata;
120
121    /// Execute the visualization
122    ///
123    /// # Arguments
124    /// * `data` - Input data to visualize
125    /// * `config` - Configuration for the visualization
126    ///
127    /// # Returns
128    /// Result containing the plugin output
129    fn execute(&self, data: VisualizationData, config: PluginConfig) -> Result<PluginResult>;
130
131    /// Validate input data
132    ///
133    /// # Arguments
134    /// * `data` - Data to validate
135    ///
136    /// # Returns
137    /// True if data is valid for this plugin
138    fn validate(&self, data: &VisualizationData) -> bool {
139        // Default implementation: accept all data
140        let _ = data;
141        true
142    }
143
144    /// Get configuration schema (for UI generation)
145    fn config_schema(&self) -> serde_json::Value {
146        serde_json::json!({
147            "type": "object",
148            "properties": {}
149        })
150    }
151}
152
153/// Plugin manager for registering and executing visualization plugins
154pub struct PluginManager {
155    /// Registered plugins
156    plugins: Arc<RwLock<HashMap<String, Box<dyn VisualizationPlugin>>>>,
157    /// Plugin execution history
158    history: Arc<RwLock<Vec<PluginExecution>>>,
159}
160
/// One entry in the manager's execution history.
#[derive(Debug, Clone)]
struct PluginExecution {
    /// Name of the plugin that was run.
    plugin_name: String,
    /// Wall-clock time at which the entry was recorded.
    timestamp: std::time::SystemTime,
    /// Whether the run was reported as successful.
    success: bool,
    /// How long the run took, in milliseconds.
    duration_ms: u128,
}
169
170impl Default for PluginManager {
171    fn default() -> Self {
172        Self::new()
173    }
174}
175
176impl PluginManager {
177    /// Create a new plugin manager
178    pub fn new() -> Self {
179        let manager = Self {
180            plugins: Arc::new(RwLock::new(HashMap::new())),
181            history: Arc::new(RwLock::new(Vec::new())),
182        };
183
184        // Register built-in plugins
185        manager.register_builtin_plugins();
186
187        manager
188    }
189
190    /// Register built-in plugins
191    fn register_builtin_plugins(&self) {
192        // Register histogram plugin
193        self.register_plugin(Box::new(HistogramPlugin)).ok();
194
195        // Register heatmap plugin
196        self.register_plugin(Box::new(HeatmapPlugin)).ok();
197
198        // Register line plot plugin
199        self.register_plugin(Box::new(LinePlotPlugin)).ok();
200
201        // Register scatter plot plugin
202        self.register_plugin(Box::new(ScatterPlotPlugin)).ok();
203    }
204
205    /// Register a new plugin
206    ///
207    /// # Arguments
208    /// * `plugin` - Plugin to register
209    pub fn register_plugin(&self, plugin: Box<dyn VisualizationPlugin>) -> Result<()> {
210        let name = plugin.metadata().name.clone();
211
212        self.plugins.write().insert(name.clone(), plugin);
213
214        tracing::info!(plugin_name = %name, "Registered visualization plugin");
215
216        Ok(())
217    }
218
219    /// Unregister a plugin
220    ///
221    /// # Arguments
222    /// * `name` - Name of plugin to unregister
223    pub fn unregister_plugin(&self, name: &str) -> Result<()> {
224        self.plugins.write().remove(name);
225
226        tracing::info!(plugin_name = %name, "Unregistered visualization plugin");
227
228        Ok(())
229    }
230
231    /// Get list of registered plugins
232    pub fn list_plugins(&self) -> Vec<PluginMetadata> {
233        self.plugins.read().values().map(|p| p.metadata()).collect()
234    }
235
236    /// Execute a plugin
237    ///
238    /// # Arguments
239    /// * `plugin_name` - Name of plugin to execute
240    /// * `data` - Input data
241    /// * `config` - Configuration
242    pub fn execute(
243        &self,
244        plugin_name: &str,
245        data: VisualizationData,
246        config: PluginConfig,
247    ) -> Result<PluginResult> {
248        let start_time = std::time::Instant::now();
249
250        let result = {
251            let plugins = self.plugins.read();
252            let plugin = plugins
253                .get(plugin_name)
254                .ok_or_else(|| anyhow::anyhow!("Plugin not found: {}", plugin_name))?;
255
256            // Validate data
257            if !plugin.validate(&data) {
258                anyhow::bail!("Invalid data for plugin: {}", plugin_name);
259            }
260
261            plugin.execute(data, config)?
262        };
263
264        let duration = start_time.elapsed().as_millis();
265
266        // Record execution
267        self.history.write().push(PluginExecution {
268            plugin_name: plugin_name.to_string(),
269            timestamp: std::time::SystemTime::now(),
270            success: result.success,
271            duration_ms: duration,
272        });
273
274        Ok(result)
275    }
276
277    /// Get plugin by name
278    pub fn get_plugin(&self, name: &str) -> Option<PluginMetadata> {
279        self.plugins.read().get(name).map(|p| p.metadata())
280    }
281
282    /// Get execution history
283    pub fn get_history(&self) -> Vec<String> {
284        self.history
285            .read()
286            .iter()
287            .map(|e| {
288                format!(
289                    "{}: {} ({}ms) - {}",
290                    e.timestamp.duration_since(std::time::UNIX_EPOCH).unwrap().as_secs(),
291                    e.plugin_name,
292                    e.duration_ms,
293                    if e.success { "success" } else { "failed" }
294                )
295            })
296            .collect()
297    }
298}
299
300// ============================================================================
301// Built-in Plugins
302// ============================================================================
303
304/// Histogram visualization plugin
305struct HistogramPlugin;
306
307impl VisualizationPlugin for HistogramPlugin {
308    fn metadata(&self) -> PluginMetadata {
309        PluginMetadata {
310            name: "histogram".to_string(),
311            version: "1.0.0".to_string(),
312            description: "Generates histogram visualizations".to_string(),
313            author: "TrustformeRS".to_string(),
314            supported_inputs: vec!["Array1D".to_string(), "Tensor".to_string()],
315            supported_outputs: vec![OutputFormat::Png, OutputFormat::Svg, OutputFormat::Text],
316            tags: vec!["distribution".to_string(), "statistics".to_string()],
317        }
318    }
319
320    fn execute(&self, data: VisualizationData, config: PluginConfig) -> Result<PluginResult> {
321        let values = match data {
322            VisualizationData::Array1D(v) => v,
323            VisualizationData::Tensor { data, .. } => data,
324            _ => anyhow::bail!("Unsupported data type for histogram"),
325        };
326
327        // Calculate histogram bins
328        let bins = config.custom_params.get("bins").and_then(|v| v.as_u64()).unwrap_or(20) as usize;
329
330        let min = values.iter().copied().fold(f64::INFINITY, f64::min);
331        let max = values.iter().copied().fold(f64::NEG_INFINITY, f64::max);
332        let bin_width = (max - min) / bins as f64;
333
334        let mut counts = vec![0; bins];
335        for &value in &values {
336            let bin_idx = ((value - min) / bin_width).floor() as usize;
337            let bin_idx = bin_idx.min(bins - 1);
338            counts[bin_idx] += 1;
339        }
340
341        // Generate output
342        let output_text = format!(
343            "Histogram (bins={}):\nMin={:.4}, Max={:.4}\nBin counts: {:?}",
344            bins, min, max, counts
345        );
346
347        Ok(PluginResult {
348            success: true,
349            output_path: None,
350            output_data: Some(output_text.into_bytes()),
351            metadata: {
352                let mut m = HashMap::new();
353                m.insert("bins".to_string(), bins.to_string());
354                m.insert("min".to_string(), min.to_string());
355                m.insert("max".to_string(), max.to_string());
356                m
357            },
358            error: None,
359        })
360    }
361
362    fn validate(&self, data: &VisualizationData) -> bool {
363        matches!(
364            data,
365            VisualizationData::Array1D(_) | VisualizationData::Tensor { .. }
366        )
367    }
368}
369
370/// Heatmap visualization plugin
371struct HeatmapPlugin;
372
373impl VisualizationPlugin for HeatmapPlugin {
374    fn metadata(&self) -> PluginMetadata {
375        PluginMetadata {
376            name: "heatmap".to_string(),
377            version: "1.0.0".to_string(),
378            description: "Generates heatmap visualizations for 2D data".to_string(),
379            author: "TrustformeRS".to_string(),
380            supported_inputs: vec!["Array2D".to_string(), "Tensor".to_string()],
381            supported_outputs: vec![OutputFormat::Png, OutputFormat::Html],
382            tags: vec!["matrix".to_string(), "2d".to_string()],
383        }
384    }
385
386    fn execute(&self, data: VisualizationData, _config: PluginConfig) -> Result<PluginResult> {
387        let (rows, cols) = match &data {
388            VisualizationData::Array2D(v) => (v.len(), v.first().map(|r| r.len()).unwrap_or(0)),
389            VisualizationData::Tensor { shape, .. } if shape.len() == 2 => (shape[0], shape[1]),
390            _ => anyhow::bail!("Heatmap requires 2D data"),
391        };
392
393        Ok(PluginResult {
394            success: true,
395            output_path: None,
396            output_data: Some(format!("Heatmap {}x{}", rows, cols).into_bytes()),
397            metadata: {
398                let mut m = HashMap::new();
399                m.insert("rows".to_string(), rows.to_string());
400                m.insert("cols".to_string(), cols.to_string());
401                m
402            },
403            error: None,
404        })
405    }
406
407    fn validate(&self, data: &VisualizationData) -> bool {
408        match data {
409            VisualizationData::Array2D(_) => true,
410            VisualizationData::Tensor { shape, .. } => shape.len() == 2,
411            _ => false,
412        }
413    }
414}
415
416/// Line plot visualization plugin
417struct LinePlotPlugin;
418
419impl VisualizationPlugin for LinePlotPlugin {
420    fn metadata(&self) -> PluginMetadata {
421        PluginMetadata {
422            name: "lineplot".to_string(),
423            version: "1.0.0".to_string(),
424            description: "Generates line plots for time series data".to_string(),
425            author: "TrustformeRS".to_string(),
426            supported_inputs: vec!["TimeSeries".to_string(), "Array1D".to_string()],
427            supported_outputs: vec![OutputFormat::Png, OutputFormat::Svg],
428            tags: vec!["timeseries".to_string(), "trend".to_string()],
429        }
430    }
431
432    fn execute(&self, data: VisualizationData, _config: PluginConfig) -> Result<PluginResult> {
433        let points = match &data {
434            VisualizationData::TimeSeries { values, .. } => values.len(),
435            VisualizationData::Array1D(v) => v.len(),
436            _ => anyhow::bail!("Line plot requires time series or 1D array data"),
437        };
438
439        Ok(PluginResult {
440            success: true,
441            output_path: None,
442            output_data: Some(format!("Line plot with {} points", points).into_bytes()),
443            metadata: {
444                let mut m = HashMap::new();
445                m.insert("points".to_string(), points.to_string());
446                m
447            },
448            error: None,
449        })
450    }
451}
452
453/// Scatter plot visualization plugin
454struct ScatterPlotPlugin;
455
456impl VisualizationPlugin for ScatterPlotPlugin {
457    fn metadata(&self) -> PluginMetadata {
458        PluginMetadata {
459            name: "scatterplot".to_string(),
460            version: "1.0.0".to_string(),
461            description: "Generates scatter plots for 2D point data".to_string(),
462            author: "TrustformeRS".to_string(),
463            supported_inputs: vec!["Array2D".to_string()],
464            supported_outputs: vec![OutputFormat::Png, OutputFormat::Html],
465            tags: vec!["correlation".to_string(), "distribution".to_string()],
466        }
467    }
468
469    fn execute(&self, data: VisualizationData, _config: PluginConfig) -> Result<PluginResult> {
470        let points = match &data {
471            VisualizationData::Array2D(v) => v.len(),
472            _ => anyhow::bail!("Scatter plot requires 2D array data"),
473        };
474
475        Ok(PluginResult {
476            success: true,
477            output_path: None,
478            output_data: Some(format!("Scatter plot with {} points", points).into_bytes()),
479            metadata: {
480                let mut m = HashMap::new();
481                m.insert("points".to_string(), points.to_string());
482                m
483            },
484            error: None,
485        })
486    }
487
488    fn validate(&self, data: &VisualizationData) -> bool {
489        matches!(data, VisualizationData::Array2D(_))
490    }
491}
492
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created manager already exposes the built-in plugins.
    #[test]
    fn test_plugin_manager_creation() {
        let registered = PluginManager::new().list_plugins();

        assert!(!registered.is_empty());
    }

    /// The histogram plugin runs on a small 1D array and returns output bytes.
    #[test]
    fn test_histogram_plugin() {
        let manager = PluginManager::new();
        let samples = VisualizationData::Array1D(vec![1.0, 2.0, 3.0, 4.0, 5.0]);

        let outcome = manager
            .execute("histogram", samples, PluginConfig::default())
            .unwrap();

        assert!(outcome.success);
        assert!(outcome.output_data.is_some());
    }

    /// 1D data is accepted by the histogram plugin but rejected by the heatmap plugin.
    #[test]
    fn test_plugin_validation() {
        let manager = PluginManager::new();

        // Valid data for histogram
        let for_histogram = VisualizationData::Array1D(vec![1.0, 2.0, 3.0]);
        let accepted = manager.execute("histogram", for_histogram, PluginConfig::default());
        assert!(accepted.is_ok());

        // Invalid data for heatmap (needs 2D)
        let for_heatmap = VisualizationData::Array1D(vec![1.0, 2.0, 3.0]);
        let rejected = manager.execute("heatmap", for_heatmap, PluginConfig::default());
        assert!(rejected.is_err());
    }

    /// Registering a plugin under an existing name replaces it instead of adding one.
    #[test]
    fn test_custom_plugin_registration() {
        let manager = PluginManager::new();
        let count_before = manager.list_plugins().len();

        // Register histogram again (should replace)
        manager.register_plugin(Box::new(HistogramPlugin)).unwrap();

        // Should be same count (replacement)
        assert_eq!(count_before, manager.list_plugins().len());
    }
}