scirs2_neural/visualization/
mod.rs

//! Visualization tools for neural networks
//!
//! This module provides comprehensive visualization capabilities, including:
//! - Network architecture visualization with interactive graphs
//! - Training curves and metrics plotting
//! - Layer activation maps and feature visualization
//! - Attention mechanism visualization
//! - Interactive dashboards and real-time monitoring
//!
//! # Module Organization
//!
//! - [`config`] - Configuration types and settings for all visualization aspects
//! - [`network`] - Network architecture visualization and layout algorithms
//! - [`training`] - Training metrics, curves, and performance monitoring
//! - [`activations`] - Layer activation analysis and feature map visualization
//! - [`attention`] - Attention mechanism visualization and analysis
//!
//! # Basic Usage
//!
//! ```rust
//! use scirs2_neural::visualization::{VisualizationConfig, NetworkVisualizer};
//! use scirs2_neural::models::Sequential;
//! use scirs2_neural::layers::Dense;
//! use rand::SeedableRng;
//!
//! // Create a simple model
//! let mut rng = rand::rngs::StdRng::seed_from_u64(42);
//! let mut model = Sequential::<f32>::new();
//! model.add_layer(Dense::new(784, 128, Some("relu"), &mut rng).unwrap());
//! model.add_layer(Dense::new(128, 10, Some("softmax"), &mut rng).unwrap());
//!
//! // Configure visualization
//! let config = VisualizationConfig::default();
//!
//! // Create network visualizer
//! let mut visualizer = NetworkVisualizer::new(model, config);
//!
//! // Generate architecture visualization.
//! // Note: rendering is not implemented yet, so the calls below are left
//! // commented out; this example only demonstrates construction.
//! // let output_path = visualizer.visualize_architecture()?;
//! // println!("Network visualization saved to: {:?}", output_path);
//! # Ok::<(), Box<dyn std::error::Error>>(())
//! ```
44
45// Module declarations
46pub mod activations;
47pub mod attention;
48pub mod config;
49pub mod network;
50pub mod training;
51
52// Re-export main configuration types
53pub use config::{
54    ColorPalette, CustomTheme, DownsamplingStrategy, FontConfig, GridConfig, ImageFormat,
55    InteractiveConfig, LayoutConfig, Margins, PerformanceConfig, StyleConfig, Theme,
56    VisualizationConfig,
57};
58
59// Re-export network visualization types and components
60pub use network::{
61    ArrowStyle, BoundingBox, Connection, ConnectionType, ConnectionVisualProps, DataFlowInfo,
62    LayerIOInfo, LayerInfo, LayerPosition, LayerVisualProps, LayoutAlgorithm, LineStyle,
63    NetworkLayout, NetworkVisualizer, Point2D, Size2D, ThroughputInfo,
64};
65
66// Re-export training visualization types and components
67pub use training::{
68    AxisConfig, AxisScale, LineStyleConfig, MarkerConfig, MarkerShape, PlotConfig, PlotType,
69    SeriesConfig, SystemMetrics, TickConfig, TickFormat, TrainingMetrics, TrainingVisualizer,
70    UpdateMode,
71};
72
73// Re-export activation visualization types and components
74pub use activations::{
75    ActivationHistogram, ActivationNormalization, ActivationStatistics,
76    ActivationVisualizationOptions, ActivationVisualizationType, ActivationVisualizer,
77    ChannelAggregation, Colormap, FeatureMapInfo,
78};
79
80// Re-export attention visualization types and components
81pub use attention::{
82    AttentionData, AttentionStatistics, AttentionVisualizationOptions, AttentionVisualizationType,
83    AttentionVisualizer, CompressionSettings, DataFormat, ExportFormat, ExportOptions,
84    ExportQuality, HeadAggregation, HeadInfo, HeadSelection, HighlightConfig, HighlightStyle,
85    Resolution, VideoFormat,
86};
87
88// Convenience type aliases for common use cases
89/// Convenient type alias for network visualization
90pub type NetworkViz<F> = NetworkVisualizer<F>;
91/// Convenient type alias for training visualization  
92pub type TrainingViz<F> = TrainingVisualizer<F>;
93/// Convenient type alias for activation visualization
94pub type ActivationViz<F> = ActivationVisualizer<F>;
95/// Convenient type alias for attention visualization
96pub type AttentionViz<F> = AttentionVisualizer<F>;
97
98/// Combined visualization suite for comprehensive neural network analysis
99pub struct VisualizationSuite<F>
100where
101    F: num_traits::Float
102        + std::fmt::Debug
103        + ndarray::ScalarOperand
104        + 'static
105        + num_traits::FromPrimitive
106        + Send
107        + Sync
108        + serde::Serialize,
109{
110    /// Network architecture visualizer
111    pub network: NetworkVisualizer<F>,
112    /// Training metrics visualizer
113    pub training: TrainingVisualizer<F>,
114    /// Activation visualizer
115    pub activation: ActivationVisualizer<F>,
116    /// Attention visualizer
117    pub attention: AttentionVisualizer<F>,
118    /// Shared configuration
119    config: VisualizationConfig,
120}
121
122impl<F> VisualizationSuite<F>
123where
124    F: num_traits::Float
125        + std::fmt::Debug
126        + ndarray::ScalarOperand
127        + 'static
128        + num_traits::FromPrimitive
129        + Send
130        + Sync
131        + serde::Serialize,
132{
133    /// Create a new comprehensive visualization suite
134    pub fn new(
135        model: crate::models::sequential::Sequential<F>,
136        config: VisualizationConfig,
137    ) -> Self {
138        let training = TrainingVisualizer::new(config.clone());
139        // Create default/empty models for visualizers that need them but don't use them actively
140        let activation = ActivationVisualizer::new(
141            crate::models::sequential::Sequential::default(),
142            config.clone(),
143        );
144        let attention = AttentionVisualizer::new(
145            crate::models::sequential::Sequential::default(),
146            config.clone(),
147        );
148        let network = NetworkVisualizer::new(model, config.clone());
149
150        Self {
151            network,
152            training,
153            activation,
154            attention,
155            config,
156        }
157    }
158
159    /// Update configuration for all visualizers
160    pub fn update_config(&mut self, config: VisualizationConfig) {
161        self.config = config.clone();
162        self.network.update_config(config.clone());
163        self.training.update_config(config.clone());
164        self.activation.update_config(config.clone());
165        self.attention.update_config(config);
166    }
167
168    /// Get the current configuration
169    pub fn get_config(&self) -> &VisualizationConfig {
170        &self.config
171    }
172
173    /// Clear all caches across visualizers
174    pub fn clear_all_caches(&mut self) {
175        self.network.clear_cache();
176        self.training.clear_history();
177        self.activation.clear_cache();
178        self.attention.clear_cache();
179    }
180}
181
182/// Builder pattern for creating visualization configurations
183pub struct VisualizationConfigBuilder {
184    config: VisualizationConfig,
185}
186
187impl VisualizationConfigBuilder {
188    /// Create a new configuration builder
189    pub fn new() -> Self {
190        Self {
191            config: VisualizationConfig::default(),
192        }
193    }
194
195    /// Set the output directory
196    pub fn output_dir<P: Into<std::path::PathBuf>>(mut self, path: P) -> Self {
197        self.config.output_dir = path.into();
198        self
199    }
200
201    /// Set the image format
202    pub fn image_format(mut self, format: ImageFormat) -> Self {
203        self.config.image_format = format;
204        self
205    }
206
207    /// Set the color palette
208    pub fn color_palette(mut self, palette: ColorPalette) -> Self {
209        self.config.style.color_palette = palette;
210        self
211    }
212
213    /// Set the theme
214    pub fn theme(mut self, theme: Theme) -> Self {
215        self.config.style.theme = theme;
216        self
217    }
218
219    /// Enable or disable interactive features
220    pub fn interactive(mut self, enable: bool) -> Self {
221        self.config.interactive.enable_interaction = enable;
222        self
223    }
224
225    /// Set canvas dimensions
226    pub fn canvas_size(mut self, width: u32, height: u32) -> Self {
227        self.config.style.layout.width = width;
228        self.config.style.layout.height = height;
229        self
230    }
231
232    /// Set performance settings
233    pub fn max_points(mut self, max_points: usize) -> Self {
234        self.config.performance.max_points_per_plot = max_points;
235        self
236    }
237
238    /// Enable or disable downsampling
239    pub fn downsampling(mut self, strategy: DownsamplingStrategy) -> Self {
240        self.config.performance.enable_downsampling = true;
241        self.config.performance.downsampling_strategy = strategy;
242        self
243    }
244
245    /// Build the configuration
246    pub fn build(self) -> VisualizationConfig {
247        self.config
248    }
249}
250
251impl Default for VisualizationConfigBuilder {
252    fn default() -> Self {
253        Self::new()
254    }
255}
256
257/// Utility functions for visualization
258pub mod utils {
259    use super::*;
260
261    /// Create a quick visualization configuration for prototyping
262    pub fn quick_config() -> VisualizationConfig {
263        VisualizationConfigBuilder::new()
264            .canvas_size(800, 600)
265            .theme(Theme::Light)
266            .color_palette(ColorPalette::Default)
267            .interactive(false)
268            .build()
269    }
270
271    /// Create a high-quality visualization configuration for publication
272    pub fn publication_config() -> VisualizationConfig {
273        VisualizationConfigBuilder::new()
274            .canvas_size(1920, 1080)
275            .image_format(ImageFormat::PDF)
276            .theme(Theme::Light)
277            .color_palette(ColorPalette::ColorblindFriendly)
278            .interactive(false)
279            .build()
280    }
281
282    /// Create an interactive visualization configuration for dashboards
283    pub fn dashboard_config() -> VisualizationConfig {
284        VisualizationConfigBuilder::new()
285            .canvas_size(1200, 800)
286            .image_format(ImageFormat::HTML)
287            .theme(Theme::Dark)
288            .color_palette(ColorPalette::HighContrast)
289            .interactive(true)
290            .max_points(50000)
291            .downsampling(DownsamplingStrategy::LTTB)
292            .build()
293    }
294}
295
#[cfg(test)]
mod tests {
    // Unit tests for the visualization suite, the config builder, and presets.
    use super::*;
    use crate::layers::Dense;
    use rand::SeedableRng;

    #[test]
    fn test_visualization_suite_creation() {
        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let mut model = crate::models::sequential::Sequential::<f32>::new();
        model.add_layer(Dense::new(10, 5, Some("relu"), &mut rng).unwrap());

        let config = VisualizationConfig::default();
        let _suite = VisualizationSuite::new(model, config);

        // Test passes if no panic occurs during creation
    }

    #[test]
    fn test_config_builder() {
        let config = VisualizationConfigBuilder::new()
            .canvas_size(1920, 1080)
            .theme(Theme::Dark)
            .color_palette(ColorPalette::HighContrast)
            .interactive(true)
            .build();

        assert_eq!(config.style.layout.width, 1920);
        assert_eq!(config.style.layout.height, 1080);
        assert_eq!(config.style.theme, Theme::Dark);
        assert_eq!(config.style.color_palette, ColorPalette::HighContrast);
        assert!(config.interactive.enable_interaction);
    }

    #[test]
    fn test_config_builder_output_and_performance() {
        // Covers the setters not exercised by `test_config_builder`.
        let config = VisualizationConfigBuilder::new()
            .output_dir("plots")
            .image_format(ImageFormat::SVG)
            .max_points(1234)
            .downsampling(DownsamplingStrategy::LTTB)
            .build();

        assert_eq!(config.output_dir, std::path::PathBuf::from("plots"));
        assert_eq!(config.image_format, ImageFormat::SVG);
        assert_eq!(config.performance.max_points_per_plot, 1234);
        assert!(config.performance.enable_downsampling);
        assert!(matches!(
            config.performance.downsampling_strategy,
            DownsamplingStrategy::LTTB
        ));
    }

    #[test]
    fn test_builder_default_matches_new() {
        let from_default = VisualizationConfigBuilder::default().build();
        let from_new = VisualizationConfigBuilder::new().build();

        assert_eq!(from_default.style.layout.width, from_new.style.layout.width);
        assert_eq!(
            from_default.style.layout.height,
            from_new.style.layout.height
        );
        assert_eq!(from_default.image_format, from_new.image_format);
        assert_eq!(from_default.style.theme, from_new.style.theme);
    }

    #[test]
    fn test_utility_configs() {
        let quick = utils::quick_config();
        assert_eq!(quick.style.layout.width, 800);
        assert_eq!(quick.style.layout.height, 600);
        assert_eq!(quick.style.theme, Theme::Light);
        assert_eq!(quick.style.color_palette, ColorPalette::Default);
        assert!(!quick.interactive.enable_interaction);

        let publication = utils::publication_config();
        assert_eq!(publication.style.layout.width, 1920);
        assert_eq!(publication.style.layout.height, 1080);
        assert_eq!(publication.image_format, ImageFormat::PDF);
        assert_eq!(publication.style.theme, Theme::Light);
        assert_eq!(
            publication.style.color_palette,
            ColorPalette::ColorblindFriendly
        );
        assert!(!publication.interactive.enable_interaction);

        let dashboard = utils::dashboard_config();
        assert_eq!(dashboard.style.layout.width, 1200);
        assert_eq!(dashboard.style.layout.height, 800);
        assert_eq!(dashboard.image_format, ImageFormat::HTML);
        assert_eq!(dashboard.style.theme, Theme::Dark);
        assert_eq!(dashboard.style.color_palette, ColorPalette::HighContrast);
        assert!(dashboard.interactive.enable_interaction);
        assert_eq!(dashboard.performance.max_points_per_plot, 50_000);
        assert!(dashboard.performance.enable_downsampling);
        assert!(matches!(
            dashboard.performance.downsampling_strategy,
            DownsamplingStrategy::LTTB
        ));
    }

    #[test]
    fn test_type_aliases() {
        // Test that type aliases compile correctly
        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let mut model = crate::models::sequential::Sequential::<f32>::new();
        model.add_layer(Dense::new(10, 5, Some("relu"), &mut rng).unwrap());

        let config = VisualizationConfig::default();

        let _network_viz: NetworkViz<f32> = NetworkVisualizer::new(model.clone(), config.clone());
        let _training_viz: TrainingViz<f32> = TrainingVisualizer::new(config.clone());
        let _activation_viz: ActivationViz<f32> =
            ActivationVisualizer::new(model.clone(), config.clone());
        let _attention_viz: AttentionViz<f32> = AttentionVisualizer::new(model, config);
    }

    #[test]
    fn test_suite_operations() {
        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let mut model = crate::models::sequential::Sequential::<f32>::new();
        model.add_layer(Dense::new(10, 5, Some("relu"), &mut rng).unwrap());

        let config = VisualizationConfig::default();
        let mut suite = VisualizationSuite::new(model, config.clone());

        assert_eq!(
            suite.get_config().style.layout.width,
            config.style.layout.width
        );

        // Clearing caches must not touch the stored configuration.
        suite.clear_all_caches();
        assert_eq!(
            suite.get_config().style.layout.width,
            config.style.layout.width
        );

        // Test config update
        let new_config = VisualizationConfigBuilder::new()
            .canvas_size(1024, 768)
            .build();
        suite.update_config(new_config);
        assert_eq!(suite.get_config().style.layout.width, 1024);
        assert_eq!(suite.get_config().style.layout.height, 768);
    }

    #[test]
    fn test_module_integration() {
        // Test that all modules are properly accessible
        use super::activations::*;
        use super::attention::*;
        use super::config::*;
        use super::network::*;
        use super::training::*;

        // Test default configurations
        let _viz_config = VisualizationConfig::default();
        let _plot_config = PlotConfig::default();
        let _activation_options = ActivationVisualizationOptions::default();
        let _attention_options = AttentionVisualizationOptions::default();
        let _export_options = ExportOptions::default();

        // Test enums
        let _image_format = ImageFormat::SVG;
        let _layout_algo = LayoutAlgorithm::Hierarchical;
        let _plot_type = PlotType::Line;
        let _colormap = Colormap::Viridis;
        let _attention_type = AttentionVisualizationType::Heatmap;
    }
}