scirs2_neural/visualization/
mod.rs1pub mod activations;
47pub mod attention;
48pub mod config;
49pub mod network;
50pub mod training;
51
52pub use config::{
54 ColorPalette, CustomTheme, DownsamplingStrategy, FontConfig, GridConfig, ImageFormat,
55 InteractiveConfig, LayoutConfig, Margins, PerformanceConfig, StyleConfig, Theme,
56 VisualizationConfig,
57};
58
59pub use network::{
61 ArrowStyle, BoundingBox, Connection, ConnectionType, ConnectionVisualProps, DataFlowInfo,
62 LayerIOInfo, LayerInfo, LayerPosition, LayerVisualProps, LayoutAlgorithm, LineStyle,
63 NetworkLayout, NetworkVisualizer, Point2D, Size2D, ThroughputInfo,
64};
65
66pub use training::{
68 AxisConfig, AxisScale, LineStyleConfig, MarkerConfig, MarkerShape, PlotConfig, PlotType,
69 SeriesConfig, SystemMetrics, TickConfig, TickFormat, TrainingMetrics, TrainingVisualizer,
70 UpdateMode,
71};
72
73pub use activations::{
75 ActivationHistogram, ActivationNormalization, ActivationStatistics,
76 ActivationVisualizationOptions, ActivationVisualizationType, ActivationVisualizer,
77 ChannelAggregation, Colormap, FeatureMapInfo,
78};
79
80pub use attention::{
82 AttentionData, AttentionStatistics, AttentionVisualizationOptions, AttentionVisualizationType,
83 AttentionVisualizer, CompressionSettings, DataFormat, ExportFormat, ExportOptions,
84 ExportQuality, HeadAggregation, HeadInfo, HeadSelection, HighlightConfig, HighlightStyle,
85 Resolution, VideoFormat,
86};
87
88pub type NetworkViz<F> = NetworkVisualizer<F>;
91pub type TrainingViz<F> = TrainingVisualizer<F>;
93pub type ActivationViz<F> = ActivationVisualizer<F>;
95pub type AttentionViz<F> = AttentionVisualizer<F>;
97
98pub struct VisualizationSuite<F>
100where
101 F: num_traits::Float
102 + std::fmt::Debug
103 + ndarray::ScalarOperand
104 + 'static
105 + num_traits::FromPrimitive
106 + Send
107 + Sync
108 + serde::Serialize,
109{
110 pub network: NetworkVisualizer<F>,
112 pub training: TrainingVisualizer<F>,
114 pub activation: ActivationVisualizer<F>,
116 pub attention: AttentionVisualizer<F>,
118 config: VisualizationConfig,
120}
121
122impl<F> VisualizationSuite<F>
123where
124 F: num_traits::Float
125 + std::fmt::Debug
126 + ndarray::ScalarOperand
127 + 'static
128 + num_traits::FromPrimitive
129 + Send
130 + Sync
131 + serde::Serialize,
132{
133 pub fn new(
135 model: crate::models::sequential::Sequential<F>,
136 config: VisualizationConfig,
137 ) -> Self {
138 let training = TrainingVisualizer::new(config.clone());
139 let activation = ActivationVisualizer::new(
141 crate::models::sequential::Sequential::default(),
142 config.clone(),
143 );
144 let attention = AttentionVisualizer::new(
145 crate::models::sequential::Sequential::default(),
146 config.clone(),
147 );
148 let network = NetworkVisualizer::new(model, config.clone());
149
150 Self {
151 network,
152 training,
153 activation,
154 attention,
155 config,
156 }
157 }
158
159 pub fn update_config(&mut self, config: VisualizationConfig) {
161 self.config = config.clone();
162 self.network.update_config(config.clone());
163 self.training.update_config(config.clone());
164 self.activation.update_config(config.clone());
165 self.attention.update_config(config);
166 }
167
168 pub fn get_config(&self) -> &VisualizationConfig {
170 &self.config
171 }
172
173 pub fn clear_all_caches(&mut self) {
175 self.network.clear_cache();
176 self.training.clear_history();
177 self.activation.clear_cache();
178 self.attention.clear_cache();
179 }
180}
181
182pub struct VisualizationConfigBuilder {
184 config: VisualizationConfig,
185}
186
187impl VisualizationConfigBuilder {
188 pub fn new() -> Self {
190 Self {
191 config: VisualizationConfig::default(),
192 }
193 }
194
195 pub fn output_dir<P: Into<std::path::PathBuf>>(mut self, path: P) -> Self {
197 self.config.output_dir = path.into();
198 self
199 }
200
201 pub fn image_format(mut self, format: ImageFormat) -> Self {
203 self.config.image_format = format;
204 self
205 }
206
207 pub fn color_palette(mut self, palette: ColorPalette) -> Self {
209 self.config.style.color_palette = palette;
210 self
211 }
212
213 pub fn theme(mut self, theme: Theme) -> Self {
215 self.config.style.theme = theme;
216 self
217 }
218
219 pub fn interactive(mut self, enable: bool) -> Self {
221 self.config.interactive.enable_interaction = enable;
222 self
223 }
224
225 pub fn canvas_size(mut self, width: u32, height: u32) -> Self {
227 self.config.style.layout.width = width;
228 self.config.style.layout.height = height;
229 self
230 }
231
232 pub fn max_points(mut self, max_points: usize) -> Self {
234 self.config.performance.max_points_per_plot = max_points;
235 self
236 }
237
238 pub fn downsampling(mut self, strategy: DownsamplingStrategy) -> Self {
240 self.config.performance.enable_downsampling = true;
241 self.config.performance.downsampling_strategy = strategy;
242 self
243 }
244
245 pub fn build(self) -> VisualizationConfig {
247 self.config
248 }
249}
250
251impl Default for VisualizationConfigBuilder {
252 fn default() -> Self {
253 Self::new()
254 }
255}
256
257pub mod utils {
259 use super::*;
260
261 pub fn quick_config() -> VisualizationConfig {
263 VisualizationConfigBuilder::new()
264 .canvas_size(800, 600)
265 .theme(Theme::Light)
266 .color_palette(ColorPalette::Default)
267 .interactive(false)
268 .build()
269 }
270
271 pub fn publication_config() -> VisualizationConfig {
273 VisualizationConfigBuilder::new()
274 .canvas_size(1920, 1080)
275 .image_format(ImageFormat::PDF)
276 .theme(Theme::Light)
277 .color_palette(ColorPalette::ColorblindFriendly)
278 .interactive(false)
279 .build()
280 }
281
282 pub fn dashboard_config() -> VisualizationConfig {
284 VisualizationConfigBuilder::new()
285 .canvas_size(1200, 800)
286 .image_format(ImageFormat::HTML)
287 .theme(Theme::Dark)
288 .color_palette(ColorPalette::HighContrast)
289 .interactive(true)
290 .max_points(50000)
291 .downsampling(DownsamplingStrategy::LTTB)
292 .build()
293 }
294}
295
#[cfg(test)]
mod tests {
    use super::*;
    use crate::layers::Dense;
    use rand::SeedableRng;

    /// Builds the model shared by the tests: a sequential model holding a
    /// single `Dense::new(10, 5, Some("relu"), ..)` layer, created with an RNG
    /// seeded with 42.
    fn single_dense_model() -> crate::models::sequential::Sequential<f32> {
        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let mut model = crate::models::sequential::Sequential::<f32>::new();
        model.add_layer(Dense::new(10, 5, Some("relu"), &mut rng).unwrap());
        model
    }

    /// Constructing a suite from a model and the default config must not panic.
    #[test]
    fn test_visualization_suite_creation() {
        let _suite = VisualizationSuite::new(single_dense_model(), VisualizationConfig::default());
    }

    /// Values passed to the builder setters end up in the built config.
    #[test]
    fn test_config_builder() {
        let built = VisualizationConfigBuilder::new()
            .canvas_size(1920, 1080)
            .theme(Theme::Dark)
            .color_palette(ColorPalette::HighContrast)
            .interactive(true)
            .build();

        let style = &built.style;
        assert_eq!(style.layout.width, 1920);
        assert_eq!(style.layout.height, 1080);
        assert_eq!(style.theme, Theme::Dark);
        assert_eq!(style.color_palette, ColorPalette::HighContrast);
        assert!(built.interactive.enable_interaction);
    }

    /// The presets in `utils` carry the settings they advertise.
    #[test]
    fn test_utility_configs() {
        let quick = utils::quick_config();
        assert_eq!(quick.style.layout.width, 800);
        assert_eq!(quick.style.layout.height, 600);
        assert_eq!(quick.style.theme, Theme::Light);
        assert!(!quick.interactive.enable_interaction);

        let publication = utils::publication_config();
        assert_eq!(publication.image_format, ImageFormat::PDF);
        assert_eq!(
            publication.style.color_palette,
            ColorPalette::ColorblindFriendly
        );

        let dashboard = utils::dashboard_config();
        assert_eq!(dashboard.image_format, ImageFormat::HTML);
        assert_eq!(dashboard.style.theme, Theme::Dark);
        assert!(dashboard.interactive.enable_interaction);
    }

    /// Each alias names the same type as the visualizer it abbreviates.
    #[test]
    fn test_type_aliases() {
        let model = single_dense_model();
        let config = VisualizationConfig::default();

        let _network_viz: NetworkViz<f32> = NetworkVisualizer::new(model.clone(), config.clone());
        let _training_viz: TrainingViz<f32> = TrainingVisualizer::new(config.clone());
        let _activation_viz: ActivationViz<f32> =
            ActivationVisualizer::new(model.clone(), config.clone());
        let _attention_viz: AttentionViz<f32> = AttentionVisualizer::new(model, config);
    }

    /// Reading the config, clearing caches and replacing the config all work
    /// on a freshly built suite.
    #[test]
    fn test_suite_operations() {
        let config = VisualizationConfig::default();
        let mut suite = VisualizationSuite::new(single_dense_model(), config.clone());

        assert_eq!(
            suite.get_config().style.layout.width,
            config.style.layout.width
        );

        suite.clear_all_caches();

        let resized = VisualizationConfigBuilder::new()
            .canvas_size(1024, 768)
            .build();
        suite.update_config(resized);
        assert_eq!(suite.get_config().style.layout.width, 1024);
    }

    /// Items from every submodule are reachable and constructible.
    #[test]
    fn test_module_integration() {
        use super::activations::*;
        use super::attention::*;
        use super::config::*;
        use super::network::*;
        use super::training::*;

        // Default-constructible option and config types.
        let _viz_config = VisualizationConfig::default();
        let _plot_config = PlotConfig::default();
        let _activation_options = ActivationVisualizationOptions::default();
        let _attention_options = AttentionVisualizationOptions::default();
        let _export_options = ExportOptions::default();

        // Enum variants from each submodule.
        let _image_format = ImageFormat::SVG;
        let _layout_algo = LayoutAlgorithm::Hierarchical;
        let _plot_type = PlotType::Line;
        let _colormap = Colormap::Viridis;
        let _attention_type = AttentionVisualizationType::Heatmap;
    }
}