Skip to main content

scirs2_neural/visualization/
network.rs

1//! Network architecture visualization for neural networks
2//!
3//! This module provides comprehensive tools for visualizing neural network architectures
4//! including layout algorithms, rendering capabilities, and interactive features.
5
6use super::config::{ImageFormat, VisualizationConfig};
7use crate::error::{NeuralError, Result};
8use crate::models::sequential::Sequential;
9use scirs2_core::numeric::{Float, NumAssign};
10use serde::Serialize;
11use std::fmt::Debug;
12use std::fs;
13use std::path::PathBuf;
14/// Network architecture visualizer
15#[allow(dead_code)]
16pub struct NetworkVisualizer<F: Float + Debug + scirs2_core::ndarray::ScalarOperand + NumAssign> {
17    /// Model to visualize
18    model: Sequential<F>,
19    /// Visualization configuration
20    config: VisualizationConfig,
21    /// Cached layout information
22    layout_cache: Option<NetworkLayout>,
23}
24/// Network layout information
25#[derive(Debug, Clone, Serialize)]
26pub struct NetworkLayout {
27    /// Layer positions
28    pub layer_positions: Vec<LayerPosition>,
29    /// Connection information
30    pub connections: Vec<Connection>,
31    /// Bounding box
32    pub bounds: BoundingBox,
33    /// Layout algorithm used
34    pub algorithm: LayoutAlgorithm,
35}
36
37/// Layer position in the visualization
38#[derive(Debug, Clone, Serialize)]
39pub struct LayerPosition {
40    /// Layer name/identifier
41    pub name: String,
42    /// Layer type
43    pub layer_type: String,
44    /// Position coordinates
45    pub position: Point2D,
46    /// Layer dimensions
47    pub size: Size2D,
48    /// Input/output information
49    pub io_info: LayerIOInfo,
50    /// Visual properties
51    pub visual_props: LayerVisualProps,
52}
53
54/// Point in 2D space
55#[derive(Debug, Clone, Serialize)]
56pub struct Point2D {
57    /// X coordinate
58    pub x: f32,
59    /// Y coordinate
60    pub y: f32,
61}
62
63/// Size in 2D space
64#[derive(Debug, Clone, Serialize)]
65pub struct Size2D {
66    /// Width
67    pub width: f32,
68    /// Height
69    pub height: f32,
70}
71
72/// Layer input/output information
73#[derive(Debug, Clone, Serialize)]
74pub struct LayerIOInfo {
75    /// Input shape
76    pub inputshape: Vec<usize>,
77    /// Output shape
78    pub outputshape: Vec<usize>,
79    /// Parameter count
80    pub parameter_count: usize,
81    /// Computation complexity (FLOPs)
82    pub flops: u64,
83}
84
85/// Visual properties for layer rendering
86#[derive(Debug, Clone, Serialize)]
87pub struct LayerVisualProps {
88    /// Fill color
89    pub fill_color: String,
90    /// Border color
91    pub border_color: String,
92    /// Border width
93    pub border_width: f32,
94    /// Opacity
95    pub opacity: f32,
96    /// Layer icon/symbol
97    pub icon: Option<String>,
98}
99
100/// Connection between layers
101#[derive(Debug, Clone, Serialize)]
102pub struct Connection {
103    /// Source layer index
104    pub from_layer: usize,
105    /// Target layer index
106    pub to_layer: usize,
107    /// Connection type
108    pub connection_type: ConnectionType,
109    /// Visual properties
110    pub visual_props: ConnectionVisualProps,
111    /// Data flow information
112    pub data_flow: DataFlowInfo,
113}
114
115/// Type of connection between layers
116#[derive(Debug, Clone, PartialEq, Serialize)]
117pub enum ConnectionType {
118    /// Standard forward connection
119    Forward,
120    /// Skip/residual connection
121    Skip,
122    /// Attention connection
123    Attention,
124    /// Recurrent connection
125    Recurrent,
126    /// Sequential connection
127    Sequential,
128    /// Lateral connection
129    Lateral,
130    /// Custom connection
131    Custom(String),
132}
133
134/// Visual properties for connection rendering
135#[derive(Debug, Clone, Serialize)]
136pub struct ConnectionVisualProps {
137    /// Line color
138    pub color: String,
139    /// Line width
140    pub width: f32,
141    /// Line style
142    pub style: LineStyle,
143    /// Arrow style
144    pub arrow: ArrowStyle,
145    /// Opacity
146    pub opacity: f32,
147}
148
149/// Line style for connections
150#[derive(Debug, Clone, PartialEq, Serialize)]
151pub enum LineStyle {
152    /// Solid line
153    Solid,
154    /// Dashed line
155    Dashed,
156    /// Dotted line
157    Dotted,
158    /// Dash-dot line
159    DashDot,
160}
161
162/// Arrow style for connections
163#[derive(Debug, Clone, PartialEq, Serialize)]
164pub enum ArrowStyle {
165    /// No arrow
166    None,
167    /// Simple arrow
168    Simple,
169    /// Block arrow
170    Block,
171    /// Curved arrow
172    Curved,
173}
174
175/// Data flow information
176#[derive(Debug, Clone, Serialize)]
177pub struct DataFlowInfo {
178    /// Tensor shape flowing through connection
179    pub tensorshape: Vec<usize>,
180    /// Data type
181    pub data_type: String,
182    /// Estimated memory usage in bytes
183    pub memory_usage: usize,
184    /// Batch size for data flow
185    pub batch_size: Option<usize>,
186    /// Throughput information
187    pub throughput: Option<ThroughputInfo>,
188}
189
190/// Throughput information for data flow
191#[derive(Debug, Clone, Serialize)]
192pub struct ThroughputInfo {
193    /// Samples per second
194    pub samples_per_second: f64,
195    /// Bytes per second
196    pub bytes_per_second: u64,
197    /// Latency in milliseconds
198    pub latency_ms: f64,
199}
200
201/// Bounding box for layout
202#[derive(Debug, Clone, Serialize)]
203pub struct BoundingBox {
204    /// Minimum X coordinate
205    pub min_x: f32,
206    /// Minimum Y coordinate
207    pub min_y: f32,
208    /// Maximum X coordinate
209    pub max_x: f32,
210    /// Maximum Y coordinate
211    pub max_y: f32,
212}
213
214/// Layout algorithm for network visualization
215#[derive(Debug, Clone, PartialEq, Serialize)]
216pub enum LayoutAlgorithm {
217    /// Hierarchical layout (top-down)
218    Hierarchical,
219    /// Force-directed layout
220    ForceDirected,
221    /// Circular layout
222    Circular,
223    /// Grid layout
224    Grid,
225    /// Custom layout
226    Custom(String),
227}
228
/// Minimal per-layer summary gathered before a layout is computed.
#[derive(Clone, Debug)]
pub struct LayerInfo {
    /// Display name; the visualizer builds it as `"{layer_type}_{layer_index}"`.
    pub layer_name: String,
    /// Zero-based position of the layer within the model.
    pub layer_index: usize,
    /// Layer kind string reported by the layer.
    pub layer_type: String,
}
239
240// Implementation for NetworkVisualizer
241impl<
242        F: Float
243            + Debug
244            + std::fmt::Display
245            + 'static
246            + scirs2_core::numeric::FromPrimitive
247            + scirs2_core::ndarray::ScalarOperand
248            + Send
249            + Sync
250            + NumAssign,
251    > NetworkVisualizer<F>
252{
253    /// Create a new network visualizer
254    pub fn new(model: Sequential<F>, config: VisualizationConfig) -> Self {
255        Self {
256            model,
257            config,
258            layout_cache: None,
259        }
260    }
261    /// Generate network architecture visualization
262    pub fn visualize_architecture(&mut self) -> Result<PathBuf> {
263        // Compute network layout
264        let layout = self.compute_layout()?;
265        self.layout_cache = Some(layout.clone());
266        // Generate visualization based on format
267        match self.config.image_format {
268            ImageFormat::SVG => self.generate_svg_visualization(&layout),
269            ImageFormat::HTML => self.generate_html_visualization(&layout),
270            ImageFormat::JSON => self.generate_json_visualization(&layout),
271            _ => self.generate_svg_visualization(&layout), // Default to SVG
272        }
273    }
274
275    /// Compute network layout using specified algorithm
276    fn compute_layout(&self) -> Result<NetworkLayout> {
277        // Analyze model structure
278        let layer_info = self.analyze_model_structure()?;
279        // Choose layout algorithm based on network complexity
280        let algorithm = self.select_layout_algorithm(&layer_info);
281        // Compute positions using selected algorithm
282        let (positions, connections) = match algorithm {
283            LayoutAlgorithm::Hierarchical => self.compute_hierarchical_layout(&layer_info)?,
284            LayoutAlgorithm::ForceDirected => self.compute_force_directed_layout(&layer_info)?,
285            LayoutAlgorithm::Circular => self.compute_circular_layout(&layer_info)?,
286            LayoutAlgorithm::Grid => self.compute_grid_layout(&layer_info)?,
287            LayoutAlgorithm::Custom(_) => self.compute_hierarchical_layout(&layer_info)?, // Fallback
288        };
289        // Compute bounding box
290        let bounds = self.compute_bounds(&positions);
291        Ok(NetworkLayout {
292            layer_positions: positions,
293            connections,
294            bounds,
295            algorithm,
296        })
297    }
298
299    fn analyze_model_structure(&self) -> Result<Vec<LayerInfo>> {
300        let mut layer_info = Vec::new();
301        // For Sequential models, we can iterate through the layers
302        let layers = self.model.layers();
303        for (index, layer) in layers.iter().enumerate() {
304            let layer_type = layer.layer_type().to_string();
305            let layer_name = format!("{layer_type}_{index}");
306            layer_info.push(LayerInfo {
307                layer_name,
308                layer_index: index,
309                layer_type,
310            });
311        }
312
313        // If no layers found, return error
314        if layer_info.is_empty() {
315            return Err(NeuralError::InvalidArgument(
316                "Model has no layers".to_string(),
317            ));
318        }
319
320        Ok(layer_info)
321    }
322
323    fn select_layout_algorithm(&self, _layer_info: &[LayerInfo]) -> LayoutAlgorithm {
324        // For now, default to hierarchical layout
325        // In a full implementation, this would analyze the network structure
326        // and choose the most appropriate layout algorithm
327        LayoutAlgorithm::Hierarchical
328    }
329
330    fn compute_hierarchical_layout(
331        &self,
332        layer_info: &[LayerInfo],
333    ) -> Result<(Vec<LayerPosition>, Vec<Connection>)> {
334        if layer_info.is_empty() {
335            return Ok((Vec::new(), Vec::new()));
336        }
337
338        let mut positions = Vec::new();
339        let mut connections = Vec::new();
340        // Layout parameters
341        let layer_width = 120.0;
342        let layer_height = 60.0;
343        let vertical_spacing = 100.0;
344        let horizontal_spacing = 150.0;
345        // Calculate total width and starting position
346        let total_width = (layer_info.len() as f32 - 1.0) * horizontal_spacing + layer_width;
347        let start_x = -total_width / 2.0 + layer_width / 2.0;
348        let start_y = -(layer_info.len() as f32 - 1.0) * vertical_spacing / 2.0;
349        // Create layer positions
350        for (i, layer) in layer_info.iter().enumerate() {
351            let x = start_x;
352            let y = start_y + i as f32 * vertical_spacing;
353            // Determine layer visual properties based on type
354            let (fill_color, border_color, icon) = match layer.layer_type.as_str() {
355                "Dense" => (
356                    "#4CAF50".to_string(),
357                    "#2E7D32".to_string(),
358                    Some("◯".to_string()),
359                ),
360                "Conv2D" => (
361                    "#2196F3".to_string(),
362                    "#1565C0".to_string(),
363                    Some("⬜".to_string()),
364                ),
365                "Conv1D" => (
366                    "#03A9F4".to_string(),
367                    "#0277BD".to_string(),
368                    Some("▬".to_string()),
369                ),
370                "MaxPool2D" | "AvgPool2D" => (
371                    "#FF9800".to_string(),
372                    "#E65100".to_string(),
373                    Some("▣".to_string()),
374                ),
375                "Dropout" => (
376                    "#9C27B0".to_string(),
377                    "#6A1B9A".to_string(),
378                    Some("×".to_string()),
379                ),
380                "BatchNorm" => (
381                    "#607D8B".to_string(),
382                    "#37474F".to_string(),
383                    Some("∼".to_string()),
384                ),
385                "Activation" => (
386                    "#FFC107".to_string(),
387                    "#F57C00".to_string(),
388                    Some("∘".to_string()),
389                ),
390                "LSTM" => (
391                    "#E91E63".to_string(),
392                    "#AD1457".to_string(),
393                    Some("⟲".to_string()),
394                ),
395                "GRU" => (
396                    "#F44336".to_string(),
397                    "#C62828".to_string(),
398                    Some("⟳".to_string()),
399                ),
400                "Attention" => (
401                    "#673AB7".to_string(),
402                    "#4527A0".to_string(),
403                    Some("◉".to_string()),
404                ),
405                _ => (
406                    "#9E9E9E".to_string(),
407                    "#424242".to_string(),
408                    Some("?".to_string()),
409                ),
410            };
411            // Estimate parameter count (simplified)
412            let parameter_count = match layer.layer_type.as_str() {
413                "Dense" => 10000, // Placeholder
414                "Conv2D" => 5000,
415                "Conv1D" => 3000,
416                _ => 0,
417            };
418
419            // Estimate FLOPs (simplified)
420            let flops = match layer.layer_type.as_str() {
421                "Dense" => 100000,
422                "Conv2D" => 500000,
423                "Conv1D" => 200000,
424                _ => 1000,
425            };
426
427            let position = LayerPosition {
428                name: layer.layer_name.clone(),
429                layer_type: layer.layer_type.clone(),
430                position: Point2D { x, y },
431                size: Size2D {
432                    width: layer_width,
433                    height: layer_height,
434                },
435                io_info: LayerIOInfo {
436                    inputshape: vec![32, 32, 3],  // Placeholder
437                    outputshape: vec![32, 32, 3], // Placeholder
438                    parameter_count,
439                    flops,
440                },
441                visual_props: LayerVisualProps {
442                    fill_color,
443                    border_color,
444                    border_width: 2.0,
445                    opacity: 0.9,
446                    icon,
447                },
448            };
449
450            positions.push(position);
451        }
452
453        // Create connections between adjacent layers
454        for i in 0..(layer_info.len().saturating_sub(1)) {
455            let connection = Connection {
456                from_layer: i,
457                to_layer: i + 1,
458                connection_type: ConnectionType::Forward,
459                visual_props: ConnectionVisualProps {
460                    color: "#666666".to_string(),
461                    width: 2.0,
462                    style: LineStyle::Solid,
463                    arrow: ArrowStyle::Simple,
464                    opacity: 0.8,
465                },
466                data_flow: DataFlowInfo {
467                    tensorshape: vec![32, 32, 3], // Placeholder
468                    data_type: "f32".to_string(),
469                    memory_usage: 4096,   // Placeholder
470                    batch_size: Some(32), // Default batch size
471                    throughput: Some(ThroughputInfo {
472                        samples_per_second: 1000.0,
473                        bytes_per_second: 4096000,
474                        latency_ms: 1.0,
475                    }),
476                },
477            };
478
479            connections.push(connection);
480        }
481
482        Ok((positions, connections))
483    }
484
485    fn compute_force_directed_layout(
486        &self,
487        layer_info: &[LayerInfo],
488    ) -> Result<(Vec<LayerPosition>, Vec<Connection>)> {
489        if layer_info.is_empty() {
490            return Ok((Vec::new(), Vec::new()));
491        }
492
493        let mut positions = Vec::new();
494        let mut connections = Vec::new();
495        // Force-directed layout parameters
496        let area = 800.0 * 600.0; // Canvas area
497        let k = (area / layer_info.len() as f32).sqrt(); // Optimal distance
498        let iterations = 100;
499        let cooling_factor = 0.95;
500        let mut temperature = 100.0;
501        // Initialize random positions
502        let mut node_positions: Vec<Point2D> = (0..layer_info.len())
503            .map(|i| Point2D {
504                x: ((i % 4) as f32 - 1.5) * 100.0, // Rough grid start
505                y: ((i / 4) as f32 - 1.5) * 100.0,
506            })
507            .collect();
508        // Force-directed algorithm iterations
509        for _iteration in 0..iterations {
510            let mut forces: Vec<Point2D> = vec![Point2D { x: 0.0, y: 0.0 }; layer_info.len()];
511            // Calculate repulsive forces between all pairs
512            for i in 0..layer_info.len() {
513                for j in (i + 1)..layer_info.len() {
514                    let dx = node_positions[i].x - node_positions[j].x;
515                    let dy = node_positions[i].y - node_positions[j].y;
516                    let distance = (dx * dx + dy * dy).sqrt().max(1.0);
517                    let repulsive_force = k * k / distance;
518                    let fx = repulsive_force * dx / distance;
519                    let fy = repulsive_force * dy / distance;
520                    forces[i].x += fx;
521                    forces[i].y += fy;
522                    forces[j].x -= fx;
523                    forces[j].y -= fy;
524                }
525            }
526            // Calculate attractive forces for connected layers (sequential connections)
527            for i in 0..(layer_info.len() - 1) {
528                let dx = node_positions[i].x - node_positions[i + 1].x;
529                let dy = node_positions[i].y - node_positions[i + 1].y;
530                let distance = (dx * dx + dy * dy).sqrt().max(1.0);
531                let attractive_force = distance * distance / k;
532                let fx = attractive_force * dx / distance;
533                let fy = attractive_force * dy / distance;
534                forces[i].x -= fx;
535                forces[i].y -= fy;
536                forces[i + 1].x += fx;
537                forces[i + 1].y += fy;
538            }
539
540            // Apply forces with temperature cooling
541            for i in 0..layer_info.len() {
542                let force_magnitude =
543                    (forces[i].x * forces[i].x + forces[i].y * forces[i].y).sqrt();
544                if force_magnitude > 0.0 {
545                    let displacement = temperature.min(force_magnitude);
546                    node_positions[i].x += forces[i].x / force_magnitude * displacement;
547                    node_positions[i].y += forces[i].y / force_magnitude * displacement;
548                }
549            }
550
551            temperature *= cooling_factor;
552        }
553
554        // Create layer positions with visual properties
555        for (i, layer) in layer_info.iter().enumerate() {
556            let position = LayerPosition {
557                name: layer.layer_name.clone(),
558                layer_type: layer.layer_type.clone(),
559                position: node_positions[i].clone(),
560                size: Size2D {
561                    width: 120.0,
562                    height: 60.0,
563                },
564                io_info: LayerIOInfo {
565                    inputshape: vec![1, 32], // Placeholder
566                    outputshape: vec![1, 32],
567                    parameter_count: 1024,
568                    flops: 2048,
569                },
570                visual_props: LayerVisualProps {
571                    fill_color: "#8BC34A".to_string(),
572                    border_color: "#558B2F".to_string(),
573                    border_width: 2.0,
574                    opacity: 0.9,
575                    icon: Some("▢".to_string()),
576                },
577            };
578            positions.push(position);
579        }
580        // Create connections between sequential layers
581        for i in 0..(layer_info.len().saturating_sub(1)) {
582            let connection = Connection {
583                from_layer: i,
584                to_layer: i + 1,
585                connection_type: ConnectionType::Sequential,
586                visual_props: ConnectionVisualProps {
587                    color: "#666666".to_string(),
588                    width: 2.0,
589                    style: LineStyle::Solid,
590                    arrow: ArrowStyle::Simple,
591                    opacity: 0.7,
592                },
593                data_flow: DataFlowInfo {
594                    tensorshape: vec![1, 32],
595                    data_type: "float32".to_string(),
596                    memory_usage: 128, // 1 * 32 * 4 bytes
597                    batch_size: Some(1),
598                    throughput: None,
599                },
600            };
601            connections.push(connection);
602        }
603
604        Ok((positions, connections))
605    }
606
607    fn compute_circular_layout(
608        &self,
609        layer_info: &[LayerInfo],
610    ) -> Result<(Vec<LayerPosition>, Vec<Connection>)> {
611        if layer_info.is_empty() {
612            return Ok((Vec::new(), Vec::new()));
613        }
614
615        let mut positions = Vec::new();
616        let mut connections = Vec::new();
617        // Circular layout parameters
618        let radius = if layer_info.len() == 1 {
619            50.0
620        } else {
621            // Calculate radius to ensure layers don't overlap
622            let circumference = layer_info.len() as f32 * 150.0; // 150px minimum spacing
623            circumference / (2.0 * std::f32::consts::PI)
624        };
625
626        let center_x = 0.0;
627        let center_y = 0.0;
628
629        // Create layer positions around the circle
630        for (i, layer) in layer_info.iter().enumerate() {
631            let angle = if layer_info.len() == 1 {
632                0.0
633            } else {
634                2.0 * std::f32::consts::PI * i as f32 / layer_info.len() as f32
635            };
636
637            let x = center_x + radius * angle.cos();
638            let y = center_y + radius * angle.sin();
639
640            let position = LayerPosition {
641                name: layer.layer_name.clone(),
642                layer_type: layer.layer_type.clone(),
643                position: Point2D { x, y },
644                size: Size2D {
645                    width: 120.0,
646                    height: 60.0,
647                },
648                io_info: LayerIOInfo {
649                    inputshape: vec![1, 32],
650                    outputshape: vec![1, 32],
651                    parameter_count: 1024,
652                    flops: 2048,
653                },
654                visual_props: LayerVisualProps {
655                    fill_color: "#FF9800".to_string(),
656                    border_color: "#E65100".to_string(),
657                    border_width: 2.0,
658                    opacity: 0.9,
659                    icon: Some("⭕".to_string()),
660                },
661            };
662            positions.push(position);
663        }
664
665        // Create connections between sequential layers
666        for i in 0..(layer_info.len().saturating_sub(1)) {
667            let connection = Connection {
668                from_layer: i,
669                to_layer: i + 1,
670                connection_type: ConnectionType::Forward,
671                visual_props: ConnectionVisualProps {
672                    color: "#666666".to_string(),
673                    width: 2.0,
674                    style: LineStyle::Solid,
675                    arrow: ArrowStyle::Simple,
676                    opacity: 0.7,
677                },
678                data_flow: DataFlowInfo {
679                    tensorshape: vec![1, 32],
680                    data_type: "float32".to_string(),
681                    memory_usage: 128,
682                    batch_size: Some(1),
683                    throughput: None,
684                },
685            };
686            connections.push(connection);
687        }
688        // For circular layout, also connect the last layer back to the first (if more than 2 layers)
689        if layer_info.len() > 2 {
690            let connection = Connection {
691                from_layer: layer_info.len() - 1,
692                to_layer: 0,
693                connection_type: ConnectionType::Recurrent,
694                visual_props: ConnectionVisualProps {
695                    color: "#999999".to_string(),
696                    width: 1.5,
697                    style: LineStyle::Dashed,
698                    arrow: ArrowStyle::Simple,
699                    opacity: 0.5,
700                },
701                data_flow: DataFlowInfo {
702                    tensorshape: vec![1, 32],
703                    data_type: "float32".to_string(),
704                    memory_usage: 128,
705                    batch_size: Some(1),
706                    throughput: None,
707                },
708            };
709            connections.push(connection);
710        }
711
712        Ok((positions, connections))
713    }
714
715    fn compute_grid_layout(
716        &self,
717        layer_info: &[LayerInfo],
718    ) -> Result<(Vec<LayerPosition>, Vec<Connection>)> {
719        if layer_info.is_empty() {
720            return Ok((Vec::new(), Vec::new()));
721        }
722
723        let mut positions = Vec::new();
724        let mut connections = Vec::new();
725        // Grid layout parameters
726        let cell_width = 180.0;
727        let cell_height = 120.0;
728        let margin = 20.0;
729        // Calculate grid dimensions (prefer square or wide rectangle)
730        let total_layers = layer_info.len();
731        let grid_cols = (total_layers as f32).sqrt().ceil() as usize;
732        let grid_rows = (total_layers as f32 / grid_cols as f32).ceil() as usize;
733        // Calculate starting position to center the grid
734        let total_width = grid_cols as f32 * cell_width;
735        let total_height = grid_rows as f32 * cell_height;
736        let start_x = -total_width / 2.0 + cell_width / 2.0;
737        let start_y = -total_height / 2.0 + cell_height / 2.0;
738        // Create layer positions in grid formation
739        for (i, layer) in layer_info.iter().enumerate() {
740            let col = i % grid_cols;
741            let row = i / grid_cols;
742            let x = start_x + col as f32 * cell_width;
743            let y = start_y + row as f32 * cell_height;
744
745            let position = LayerPosition {
746                name: layer.layer_name.clone(),
747                layer_type: layer.layer_type.clone(),
748                position: Point2D { x, y },
749                size: Size2D {
750                    width: cell_width - margin,
751                    height: cell_height - margin,
752                },
753                io_info: LayerIOInfo {
754                    inputshape: vec![1, 32],
755                    outputshape: vec![1, 32],
756                    parameter_count: 1024,
757                    flops: 2048,
758                },
759                visual_props: LayerVisualProps {
760                    fill_color: "#2196F3".to_string(),
761                    border_color: "#1565C0".to_string(),
762                    border_width: 2.0,
763                    opacity: 0.9,
764                    icon: Some("⬜".to_string()),
765                },
766            };
767            positions.push(position);
768        }
769
770        // Create connections between sequential layers
771        for i in 0..(layer_info.len().saturating_sub(1)) {
772            let from_col = i % grid_cols;
773            let from_row = i / grid_cols;
774            let to_col = (i + 1) % grid_cols;
775            let to_row = (i + 1) / grid_cols;
776
777            // Determine connection visual style based on grid position relationship
778            let (color, style, width) = if from_row == to_row {
779                // Same row - horizontal connection
780                ("#4CAF50".to_string(), LineStyle::Solid, 2.5)
781            } else if from_col == to_col {
782                // Same column - vertical connection
783                ("#2196F3".to_string(), LineStyle::Solid, 2.5)
784            } else {
785                // Diagonal connection
786                ("#FF9800".to_string(), LineStyle::Dashed, 2.0)
787            };
788
789            let connection = Connection {
790                from_layer: i,
791                to_layer: i + 1,
792                connection_type: ConnectionType::Forward,
793                visual_props: ConnectionVisualProps {
794                    color,
795                    width,
796                    style,
797                    arrow: ArrowStyle::Simple,
798                    opacity: 0.7,
799                },
800                data_flow: DataFlowInfo {
801                    tensorshape: vec![1, 32],
802                    data_type: "float32".to_string(),
803                    memory_usage: 128,
804                    batch_size: Some(1),
805                    throughput: None,
806                },
807            };
808            connections.push(connection);
809        }
810
811        // Add some additional connections for grid pattern visualization
812        // Connect layers in the same row (if there are multiple rows)
813        if grid_rows > 1 {
814            for row in 0..grid_rows {
815                for col in 0..(grid_cols - 1) {
816                    let from_idx = row * grid_cols + col;
817                    let to_idx = row * grid_cols + col + 1;
818                    if from_idx < total_layers && to_idx < total_layers && from_idx + 1 != to_idx {
819                        let connection = Connection {
820                            from_layer: from_idx,
821                            to_layer: to_idx,
822                            connection_type: ConnectionType::Lateral,
823                            data_flow: DataFlowInfo {
824                                tensorshape: vec![1, 16],
825                                data_type: "float32".to_string(),
826                                memory_usage: 64, // 1 * 16 * 4 bytes
827                                batch_size: Some(1),
828                                throughput: None,
829                            },
830                            visual_props: ConnectionVisualProps {
831                                color: "#9E9E9E".to_string(),
832                                width: 1.0,
833                                style: LineStyle::Dotted,
834                                arrow: ArrowStyle::None,
835                                opacity: 0.4,
836                            },
837                        };
838                        connections.push(connection);
839                    }
840                }
841            }
842        }
843
844        Ok((positions, connections))
845    }
846
847    fn compute_bounds(&self, positions: &[LayerPosition]) -> BoundingBox {
848        if positions.is_empty() {
849            return BoundingBox {
850                min_x: 0.0,
851                min_y: 0.0,
852                max_x: 100.0,
853                max_y: 100.0,
854            };
855        }
856
857        let mut min_x = f32::INFINITY;
858        let mut min_y = f32::INFINITY;
859        let mut max_x = f32::NEG_INFINITY;
860        let mut max_y = f32::NEG_INFINITY;
861        for pos in positions {
862            min_x = min_x.min(pos.position.x - pos.size.width / 2.0);
863            min_y = min_y.min(pos.position.y - pos.size.height / 2.0);
864            max_x = max_x.max(pos.position.x + pos.size.width / 2.0);
865            max_y = max_y.max(pos.position.y + pos.size.height / 2.0);
866        }
867
868        BoundingBox {
869            min_x,
870            min_y,
871            max_x,
872            max_y,
873        }
874    }
875
876    fn generate_svg_visualization(&self, layout: &NetworkLayout) -> Result<PathBuf> {
877        let output_path = self.config.output_dir.join("network_architecture.svg");
878        // Generate SVG content
879        let svg_content = self.create_svg_content(layout)?;
880        // Write to file
881        fs::write(&output_path, svg_content)
882            .map_err(|e| NeuralError::IOError(format!("Failed to write SVG file: {e}")))?;
883
884        Ok(output_path)
885    }
886
887    fn generate_html_visualization(&self, layout: &NetworkLayout) -> Result<PathBuf> {
888        let output_path = self.config.output_dir.join("network_architecture.html");
889        // Generate HTML content with interactive features
890        let html_content = self.create_html_content(layout)?;
891
892        fs::write(&output_path, html_content)
893            .map_err(|e| NeuralError::IOError(format!("Failed to write HTML file: {e}")))?;
894
895        Ok(output_path)
896    }
897
898    fn generate_json_visualization(&self, layout: &NetworkLayout) -> Result<PathBuf> {
899        let output_path = self.config.output_dir.join("network_architecture.json");
900        // Serialize layout to JSON
901        let json_content = serde_json::to_string_pretty(&layout).map_err(|e| {
902            NeuralError::SerializationError(format!("Failed to serialize layout: {e}"))
903        })?;
904
905        fs::write(&output_path, json_content)
906            .map_err(|e| NeuralError::IOError(format!("Failed to write JSON file: {e}")))?;
907
908        Ok(output_path)
909    }
910
911    fn create_svg_content(&self, layout: &NetworkLayout) -> Result<String> {
912        let bounds = &layout.bounds;
913        let margin = 50.0;
914        // Calculate SVG dimensions
915        let svg_width = (bounds.max_x - bounds.min_x + 2.0 * margin) as u32;
916        let svg_height = (bounds.max_y - bounds.min_y + 2.0 * margin) as u32;
917        // Calculate viewBox to center the network
918        let viewbox_x = bounds.min_x - margin;
919        let viewbox_y = bounds.min_y - margin;
920        let viewbox_width = bounds.max_x - bounds.min_x + 2.0 * margin;
921        let viewbox_height = bounds.max_y - bounds.min_y + 2.0 * margin;
922        let mut svg = format!(
923            "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n\
924<svg width=\"{}\" height=\"{}\" viewBox=\"{} {} {} {}\" xmlns=\"http://www.w3.org/2000/svg\">\n\
925  <title>Neural Network Architecture</title>\n\
926  <defs>\n\
927    <style>\n\
928      .layer-rect {{ stroke-width: 2; }}\n\
929      .connection {{ fill: none; marker-end: url(#arrowhead); }}\n\
930      .layer-text {{ font-family: Arial, sans-serif; font-size: 11px; text-anchor: middle; fill: white; font-weight: bold; }}\n\
931      .layer-info {{ font-family: Arial, sans-serif; font-size: 9px; text-anchor: middle; fill: #333; }}\n\
932      .layer-icon {{ font-family: Arial, sans-serif; font-size: 16px; text-anchor: middle; fill: white; font-weight: bold; }}\n\
933    </style>\n\
934    <marker id=\"arrowhead\" markerWidth=\"10\" markerHeight=\"7\" refX=\"10\" refY=\"3.5\" orient=\"auto\">\n\
935      <polygon points=\"0 0, 10 3.5, 0 7\" fill=\"#{}\"/>\n\
936    </marker>\n\
937  </defs>\n\
938  \n\
939  <!-- Background -->\n\
940  <rect x=\"{}\" y=\"{}\" width=\"{}\" height=\"{}\" fill=\"#{}\" stroke=\"#{}\"/>\n\
941  \n",
942            svg_width, svg_height, viewbox_x, viewbox_y, viewbox_width, viewbox_height,
943            "666666",
944            viewbox_x, viewbox_y, viewbox_width, viewbox_height, "f8f9fa", "dee2e6"
945        );
946        // Draw connections first (so they appear behind layers)
947        for connection in &layout.connections {
948            if connection.from_layer < layout.layer_positions.len()
949                && connection.to_layer < layout.layer_positions.len()
950            {
951                let from_pos = &layout.layer_positions[connection.from_layer];
952                let to_pos = &layout.layer_positions[connection.to_layer];
953                // Calculate connection points (bottom of source to top of target)
954                let x1 = from_pos.position.x;
955                let y1 = from_pos.position.y + from_pos.size.height / 2.0;
956                let x2 = to_pos.position.x;
957                let y2 = to_pos.position.y - to_pos.size.height / 2.0;
958                let stroke_width = connection.visual_props.width;
959                let stroke_color = &connection.visual_props.color;
960                let opacity = connection.visual_props.opacity;
961                svg.push_str(&format!(
962                    r#"  <line x1="{}" y1="{}" x2="{}" y2="{}" stroke="{}" stroke-width="{}" opacity="{}" class="connection"/>
963"#,
964                    x1, y1, x2, y2, stroke_color, stroke_width, opacity
965                ));
966            }
967        }
968
969        // Draw layers
970        for (i, layer_pos) in layout.layer_positions.iter().enumerate() {
971            let x = layer_pos.position.x - layer_pos.size.width / 2.0;
972            let y = layer_pos.position.y - layer_pos.size.height / 2.0;
973            let width = layer_pos.size.width;
974            let height = layer_pos.size.height;
975            let fill_color = &layer_pos.visual_props.fill_color;
976            let border_color = &layer_pos.visual_props.border_color;
977            let border_width = layer_pos.visual_props.border_width;
978            let opacity = layer_pos.visual_props.opacity;
979            // Draw layer rectangle
980            svg.push_str(&format!(
981                r#"  <rect x="{}" y="{}" width="{}" height="{}" fill="{}" stroke="{}" stroke-width="{}" opacity="{}" rx="5" class="layer-rect"/>
982"#,
983                x, y, width, height, fill_color, border_color, border_width, opacity
984            ));
985
986            // Draw layer icon if available
987            if let Some(ref icon) = layer_pos.visual_props.icon {
988                svg.push_str(&format!(
989                    r#"  <text x="{}" y="{}" class="layer-icon">{}</text>
990"#,
991                    layer_pos.position.x,
992                    layer_pos.position.y - 5.0,
993                    icon
994                ));
995            }
996
997            // Draw layer name
998            svg.push_str(&format!(
999                r#"  <text x="{}" y="{}" class="layer-text">{}</text>
1000"#,
1001                layer_pos.position.x,
1002                layer_pos.position.y + 8.0,
1003                layer_pos.layer_type
1004            ));
1005            // Draw parameter info below the layer
1006            let param_text = if layer_pos.io_info.parameter_count > 0 {
1007                format!("{}K params", layer_pos.io_info.parameter_count / 1000)
1008            } else {
1009                "No params".to_string()
1010            };
1011
1012            svg.push_str(&format!(
1013                r#"  <text x="{}" y="{}" class="layer-info">{}</text>
1014"#,
1015                layer_pos.position.x,
1016                y + height + 15.0,
1017                param_text
1018            ));
1019
1020            // Draw layer index
1021            svg.push_str(&format!(
1022                r#"  <text x="{}" y="{}" class="layer-info">Layer {}</text>
1023"#,
1024                layer_pos.position.x,
1025                y - 10.0,
1026                i
1027            ));
1028        }
1029        // Add legend
1030        let legend_x = viewbox_x + 10.0;
1031        let legend_y = viewbox_y + viewbox_height - 100.0;
1032        svg.push_str(&format!(
1033            "  <!-- Legend -->\n\
1034  <rect x=\"{}\" y=\"{}\" width=\"200\" height=\"80\" fill=\"white\" stroke=\"#{}\" stroke-width=\"1\" opacity=\"0.9\" rx=\"5\"/>\n\
1035  <text x=\"{}\" y=\"{}\" font-family=\"Arial\" font-size=\"12\" font-weight=\"bold\" fill=\"#333\">Legend</text>\n\
1036  <text x=\"{}\" y=\"{}\" font-family=\"Arial\" font-size=\"10\" fill=\"#666\">◯ Dense Layer</text>\n\
1037  <text x=\"{}\" y=\"{}\" font-family=\"Arial\" font-size=\"10\" fill=\"#666\">⬜ Conv2D Layer</text>\n\
1038  <text x=\"{}\" y=\"{}\" font-family=\"Arial\" font-size=\"10\" fill=\"#666\">× Dropout Layer</text>\n\
1039  <text x=\"{}\" y=\"{}\" font-family=\"Arial\" font-size=\"10\" fill=\"#666\">∼ BatchNorm Layer</text>\n",
1040            legend_x, legend_y, "ccc",
1041            legend_x + 10.0, legend_y + 15.0,
1042            legend_x + 10.0, legend_y + 30.0,
1043            legend_x + 10.0, legend_y + 45.0,
1044            legend_x + 10.0, legend_y + 60.0,
1045            legend_x + 10.0, legend_y + 75.0
1046        ));
1047
1048        svg.push_str("</svg>");
1049
1050        Ok(svg)
1051    }
1052
1053    fn create_html_content(&self, layout: &NetworkLayout) -> Result<String> {
1054        // Generate SVG content for embedding
1055        let svg_content = self.create_svg_content(layout)?;
1056        // Create the interactive HTML with embedded SVG and JavaScript controls
1057        let html_content = format!(
1058            r#"<!DOCTYPE html>
1059<html lang="en">
1060<head>
1061    <meta charset="UTF-8">
1062    <meta name="viewport" content="width=device-width, initial-scale=1.0">
1063    <title>Interactive Neural Network Architecture</title>
1064    <style>
1065        body {{
1066            font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
1067            margin: 0;
1068            padding: 20px;
1069            background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
1070            color: #333;
1071        }}
1072        
1073        .header {{
1074            text-align: center;
1075            margin-bottom: 30px;
1076            background: white;
1077            border-radius: 10px;
1078            box-shadow: 0 4px 6px rgba(0,0,0,0.1);
1079        .controls {{
1080            margin-bottom: 20px;
1081        .control-group {{
1082            display: inline-block;
1083            margin-right: 20px;
1084            vertical-align: top;
1085        .control-group label {{
1086            display: block;
1087            font-weight: bold;
1088            margin-bottom: 5px;
1089            color: #555;
1090        button {{
1091            padding: 10px 20px;
1092            margin: 5px;
1093            border: none;
1094            border-radius: 5px;
1095            cursor: pointer;
1096            font-size: 14px;
1097            transition: all 0.3s ease;
1098            background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
1099            color: white;
1100        button:hover {{
1101            transform: translateY(-2px);
1102            box-shadow: 0 5px 15px rgba(0,0,0,0.2);
1103        button:active {{
1104            transform: translateY(0);
1105        select {{
1106            padding: 8px 12px;
1107            border: 1px solid #ddd;
1108        #visualization {{
1109            overflow: hidden;
1110            position: relative;
1111        #network-svg {{
1112            width: 100%;
1113            height: 700px;
1114            transition: transform 0.3s ease;
1115        .layer-node {{
1116        .layer-node:hover {{
1117            stroke-width: 3;
1118            filter: drop-shadow(0 4px 8px rgba(0,0,0,0.3));
1119        .connection-line {{
1120        .connection-line:hover {{
1121            stroke-width: 4;
1122            opacity: 1;
1123        .info-panel {{
1124            position: absolute;
1125            top: 10px;
1126            right: 10px;
1127            background: rgba(255,255,255,0.95);
1128            padding: 15px;
1129            border-radius: 8px;
1130            max-width: 300px;
1131            display: none;
1132        .info-panel h3 {{
1133            margin: 0 0 10px 0;
1134            color: #444;
1135        .info-panel p {{
1136            margin: 5px 0;
1137            font-size: 13px;
1138            color: #666;
1139        .layout-controls {{
1140            margin-bottom: 10px;
1141        .hidden {{
1142        .highlight {{
1143            stroke: #ff6b6b !important;
1144            stroke-width: 4 !important;
1145            filter: drop-shadow(0 0 10px #ff6b6b);
1146    </style>
1147</head>
1148<body>
1149    <div class="header">
1150        <h1>Interactive Neural Network Architecture</h1>
1151        <p>Algorithm: {algorithm} | Layers: {layer_count} | Connections: {connection_count}</p>
1152    </div>
1153    
1154    <div class="controls">
1155        <div class="control-group">
1156            <label>Zoom Controls:</label>
1157            <button onclick="zoomIn()">🔍+ Zoom In</button>
1158            <button onclick="zoomOut()">🔍- Zoom Out</button>
1159            <button onclick="resetView()">🎯 Reset View</button>
1160        </div>
1161            <label>Display Options:</label>
1162            <button onclick="toggleLabels()">🏷️ Toggle Labels</button>
1163            <button onclick="toggleConnections()">🔗 Toggle Connections</button>
1164            <button onclick="highlightPath()">⚡ Highlight Data Flow</button>
1165            <label>Layout Algorithm:</label>
1166            <select id="layoutSelect" onchange="changeLayout()">
1167                <option value="hierarchical">📊 Hierarchical</option>
1168                <option value="force-directed">🌟 Force-Directed</option>
1169                <option value="circular">⭕ Circular</option>
1170                <option value="grid">⬜ Grid</option>
1171            </select>
1172            <label>Animation:</label>
1173            <button onclick="animateDataFlow()">🎬 Animate Flow</button>
1174            <button onclick="showLayerDetails()">📋 Layer Details</button>
1175    <div id="visualization">
1176        <div id="network-svg-container">
1177            {svg_content}
1178        <div id="info-panel" class="info-panel">
1179            <h3 id="info-title">Layer Information</h3>
1180            <p><strong>Type:</strong> <span id="info-type">-</span></p>
1181            <p><strong>Input Shape:</strong> <span id="info-input">-</span></p>
1182            <p><strong>Output Shape:</strong> <span id="info-output">-</span></p>
1183            <p><strong>Parameters:</strong> <span id="info-params">-</span></p>
1184            <p><strong>FLOPs:</strong> <span id="info-flops">-</span></p>
1185    <script>
1186        // Global state
1187        let currentZoom = 1.0;
1188        let showLabels = true;
1189        let showConnections = true;
1190        let selectedLayer = null;
1191        let animationRunning = false;
1192        // SVG manipulation
1193        const svg = document.querySelector('#network-svg-container svg');
1194        const infoPanel = document.getElementById('info-panel');
1195        // Zoom functions
1196        function zoomIn() {{
1197            currentZoom = Math.min(currentZoom * 1.2, 3.0);
1198            updateZoom();
1199        function zoomOut() {{
1200            currentZoom = Math.max(currentZoom / 1.2, 0.3);
1201        function resetView() {{
1202            currentZoom = 1.0;
1203            clearHighlights();
1204            hideInfo();
1205        function updateZoom() {{
1206            if (svg) {{
1207                svg.style.transform = `scale(${{currentZoom}})`;
1208            }}
1209        // Label toggle
1210        function toggleLabels() {{
1211            showLabels = !showLabels;
1212            const labels = svg.querySelectorAll('text');
1213            labels.forEach(label => {{
1214                label.style.display = showLabels ? 'block' : 'none';
1215            }});
1216        // Connection toggle
1217        function toggleConnections() {{
1218            showConnections = !showConnections;
1219            const connections = svg.querySelectorAll('.connection-line, line[stroke]');
1220            connections.forEach(conn => {{
1221                conn.style.display = showConnections ? 'block' : 'none';
1222        // Highlight data flow path
1223        function highlightPath() {{
1224            const layers = svg.querySelectorAll('rect, circle, ellipse');
1225            const connections = svg.querySelectorAll('line[stroke], path[stroke]');
1226            
1227            // Sequential highlighting with delay
1228            layers.forEach((layer, index) => {{
1229                setTimeout(() => {{
1230                    layer.classList.add('highlight');
1231                    setTimeout(() => layer.classList.remove('highlight'), 1000);
1232                }}, index * 200);
1233            connections.forEach((conn, index) => {{
1234                    conn.classList.add('highlight');
1235                    setTimeout(() => conn.classList.remove('highlight'), 1000);
1236                }}, index * 200 + 100);
1237        // Animate data flow
1238        function animateDataFlow() {{
1239            if (animationRunning) return;
1240            animationRunning = true;
1241                    conn.style.strokeDasharray = '10,5';
1242                    conn.style.strokeDashoffset = '0';
1243                    conn.style.animation = 'flow 2s linear infinite';
1244                }}, index * 100);
1245            // Add CSS animation dynamically
1246            const style = document.createElement('style');
1247            style.textContent = `
1248                @keyframes flow {{
1249                    to {{ stroke-dashoffset: -15; }}
1250                }}
1251            `;
1252            document.head.appendChild(style);
1253            setTimeout(() => {{
1254                connections.forEach(conn => {{
1255                    conn.style.animation = '';
1256                    conn.style.strokeDasharray = '';
1257                    conn.style.strokeDashoffset = '';
1258                }});
1259                animationRunning = false;
1260            }}, 5000);
1261        // Layer details
1262        function showLayerDetails() {{
1263                layer.addEventListener('click', () => showLayerInfo(layer, index));
1264                layer.style.cursor = 'pointer';
1265        function showLayerInfo(layer, index) {{
1266            selectedLayer = layer;
1267            // Highlight selected layer
1268            layer.classList.add('highlight');
1269            // Show info panel with layer details
1270            document.getElementById('info-title').textContent = `Layer ${{index + 1}}`;
1271            document.getElementById('info-type').textContent = layer.getAttribute('data-type') || 'Unknown';
1272            document.getElementById('info-input').textContent = layer.getAttribute('data-input') || '[1, 32]';
1273            document.getElementById('info-output').textContent = layer.getAttribute('data-output') || '[1, 32]';
1274            document.getElementById('info-params').textContent = layer.getAttribute('data-params') || '1,024';
1275            document.getElementById('info-flops').textContent = layer.getAttribute('data-flops') || '2,048';
1276            infoPanel.style.display = 'block';
1277        function hideInfo() {{
1278            infoPanel.style.display = 'none';
1279            selectedLayer = null;
1280        function clearHighlights() {{
1281            const highlighted = svg.querySelectorAll('.highlight');
1282            highlighted.forEach(el => el.classList.remove('highlight'));
1283        // Layout change implementation
1284        function changeLayout() {{
1285            const select = document.getElementById('layoutSelect');
1286            const layout = select.value;
1287            console.log(`Switching to ${{layout}} layout`);
1288            // Apply different layout algorithms
1289            switch(layout) {{
1290                case 'hierarchical':
1291                    applyHierarchicalLayout();
1292                    break;
1293                case 'circular':
1294                    applyCircularLayout();
1295                case 'force':
1296                    applyForceDirectedLayout();
1297                case 'grid':
1298                    applyGridLayout();
1299                default:
1300                    applyDefaultLayout();
1301        function applyHierarchicalLayout() {{
1302            const width = svg.viewBox.baseVal.width || 800;
1303            const height = svg.viewBox.baseVal.height || 600;
1304            const margin = 50;
1305                const x = margin + (index % 4) * (width - 2 * margin) / 3;
1306                const y = margin + Math.floor(index / 4) * (height - 2 * margin) / 3;
1307                layer.setAttribute('x', x);
1308                layer.setAttribute('y', y);
1309        function applyCircularLayout() {{
1310            const centerX = (svg.viewBox.baseVal.width || 800) / 2;
1311            const centerY = (svg.viewBox.baseVal.height || 600) / 2;
1312            const radius = Math.min(centerX, centerY) - 100;
1313                const angle = (2 * Math.PI * index) / layers.length;
1314                const x = centerX + radius * Math.cos(angle);
1315                const y = centerY + radius * Math.sin(angle);
1316        function applyForceDirectedLayout() {{
1317            // Simple force-directed positioning
1318                const x = Math.random() * (width - 100) + 50;
1319                const y = Math.random() * (height - 100) + 50;
1320        function applyGridLayout() {{
1321            const cols = Math.ceil(Math.sqrt(layers.length));
1322            const rows = Math.ceil(layers.length / cols);
1323                const col = index % cols;
1324                const row = Math.floor(index / cols);
1325                const x = 50 + col * (width - 100) / cols;
1326                const y = 50 + row * (height - 100) / rows;
1327        function applyDefaultLayout() {{
1328                const x = 50 + (index * 100) % (width - 100);
1329                const y = 100 + Math.floor((index * 100) / (width - 100)) * 80;
1330                layer.setAttribute('x', x);
1331                layer.setAttribute('y', y);
1332            }});
1333        }}
1334
1335        function applyDefaultLayout() {{
1336            const width = svg.viewBox.baseVal.width || 800;
1337            const height = svg.viewBox.baseVal.height || 600;
1338            layers.forEach((layer, index) => {{
1339                const x = 50 + (index * 100) % (width - 100);
1340                const y = 100 + Math.floor((index * 100) / (width - 100)) * 80;
1341                layer.setAttribute('x', x);
1342                layer.setAttribute('y', y);
1343            }});
1344        }}
1345
1346        // Initialize interactive features
1347        document.addEventListener('DOMContentLoaded', function() {{
1348            // Add event listeners to existing SVG elements
1349            showLayerDetails();
1350            // Close info panel when clicking outside
1351            document.addEventListener('click', function(e) {{
1352                if (!infoPanel.contains(e.target) && !e.target.closest('rect, circle, ellipse')) {{
1353                    hideInfo();
1354                    clearHighlights();
1355                }}
1356            }});
1357
1358            // Keyboard shortcuts
1359            document.addEventListener('keydown', function(e) {{
1360                switch(e.key) {{
1361                    case '+':
1362                    case '=':
1363                        zoomIn();
1364                        break;
1365                    case '-':
1366                        zoomOut();
1367                        break;
1368                    case '0':
1369                        resetView();
1370                        break;
1371                    case 'l':
1372                        toggleLabels();
1373                        break;
1374                    case 'c':
1375                        toggleConnections();
1376                        break;
1377                    case 'h':
1378                        highlightPath();
1379                        break;
1380                }}
1381            }});
1382        }});
1383    </script>
1384</body>
1385</html>"#,
1386            algorithm = format_args!("{:?}", layout.algorithm),
1387            layer_count = layout.layer_positions.len(),
1388            connection_count = layout.connections.len(),
1389            svg_content = svg_content
1390        );
1391
1392        Ok(html_content)
1393    }
1394
1395    /// Get the cached layout if available
1396    pub fn get_cached_layout(&self) -> Option<&NetworkLayout> {
1397        self.layout_cache.as_ref()
1398    }
1399
1400    /// Clear the layout cache
1401    pub fn clear_cache(&mut self) {
1402        self.layout_cache = None;
1403    }
1404
1405    /// Update the visualization configuration
1406    pub fn update_config(&mut self, config: VisualizationConfig) {
1407        self.config = config;
1408        self.clear_cache(); // Clear cache when config changes
1409    }
1410}
1411
#[cfg(test)]
mod tests {
    use super::*;
    use crate::layers::Dense;
    use scirs2_core::random::SeedableRng;

    /// Build a visualizer around a single `Dense(10 -> 5, relu)` layer with the
    /// default configuration and a fixed RNG seed, for reproducible tests.
    fn single_dense_visualizer() -> NetworkVisualizer<f32> {
        let mut rng = scirs2_core::random::rngs::StdRng::seed_from_u64(42);
        let mut model = Sequential::<f32>::new();
        model.add_layer(Dense::new(10, 5, Some("relu"), &mut rng).expect("Operation failed"));
        NetworkVisualizer::new(model, VisualizationConfig::default())
    }

    #[test]
    fn test_network_visualizer_creation() {
        let visualizer = single_dense_visualizer();

        assert!(visualizer.layout_cache.is_none());
        assert!(visualizer.get_cached_layout().is_none());
    }

    #[test]
    fn test_layout_algorithm_variants() {
        let hierarchical = LayoutAlgorithm::Hierarchical;
        let force_directed = LayoutAlgorithm::ForceDirected;
        let circular = LayoutAlgorithm::Circular;
        let grid = LayoutAlgorithm::Grid;
        assert_eq!(hierarchical, LayoutAlgorithm::Hierarchical);
        assert_eq!(force_directed, LayoutAlgorithm::ForceDirected);
        assert_eq!(circular, LayoutAlgorithm::Circular);
        assert_eq!(grid, LayoutAlgorithm::Grid);
    }

    #[test]
    fn test_connection_types() {
        let forward = ConnectionType::Forward;
        let skip = ConnectionType::Skip;
        let attention = ConnectionType::Attention;
        let recurrent = ConnectionType::Recurrent;
        let custom = ConnectionType::Custom("test".to_string());
        assert_eq!(forward, ConnectionType::Forward);
        assert_eq!(skip, ConnectionType::Skip);
        assert_eq!(attention, ConnectionType::Attention);
        assert_eq!(recurrent, ConnectionType::Recurrent);
        match custom {
            ConnectionType::Custom(name) => assert_eq!(name, "test"),
            _ => unreachable!("Expected custom connection type"),
        }
    }

    #[test]
    fn test_bounding_box_computation() {
        let visualizer = single_dense_visualizer();

        // An empty layout falls back to the fixed 0..100 placeholder square.
        let empty_positions: &[LayerPosition] = &[];
        let bounds = visualizer.compute_bounds(empty_positions);
        assert_eq!(bounds.min_x, 0.0);
        assert_eq!(bounds.min_y, 0.0);
        assert_eq!(bounds.max_x, 100.0);
        assert_eq!(bounds.max_y, 100.0);
    }

    #[test]
    fn test_point_2d() {
        let point = Point2D { x: 10.0, y: 20.0 };

        assert_eq!(point.x, 10.0);
        assert_eq!(point.y, 20.0);
    }

    #[test]
    fn test_size_2d() {
        let size = Size2D {
            width: 100.0,
            height: 50.0,
        };

        assert_eq!(size.width, 100.0);
        assert_eq!(size.height, 50.0);
    }

    #[test]
    fn test_line_style_variants() {
        assert_eq!(LineStyle::Solid, LineStyle::Solid);
        assert_eq!(LineStyle::Dashed, LineStyle::Dashed);
        assert_eq!(LineStyle::Dotted, LineStyle::Dotted);
        assert_eq!(LineStyle::DashDot, LineStyle::DashDot);
    }

    #[test]
    fn test_arrow_style_variants() {
        assert_eq!(ArrowStyle::None, ArrowStyle::None);
        assert_eq!(ArrowStyle::Simple, ArrowStyle::Simple);
        assert_eq!(ArrowStyle::Block, ArrowStyle::Block);
        assert_eq!(ArrowStyle::Curved, ArrowStyle::Curved);
    }
}