1use super::config::{ImageFormat, VisualizationConfig};
7use crate::error::{NeuralError, Result};
8use crate::models::sequential::Sequential;
9use scirs2_core::numeric::{Float, NumAssign};
10use serde::Serialize;
11use std::fmt::Debug;
12use std::fs;
13use std::path::PathBuf;
/// Renders a [`Sequential`] model's architecture to a file.
///
/// The output format is taken from `config.image_format` (SVG, HTML or JSON;
/// other formats fall back to SVG) and files are written under
/// `config.output_dir`. The visualizer owns the model it draws and remembers
/// the most recently computed [`NetworkLayout`].
// NOTE(review): `dead_code` is presumably allowed because `layout_cache` is
// written by `visualize_architecture` but never read in the code visible
// here — confirm before removing the attribute.
#[allow(dead_code)]
pub struct NetworkVisualizer<F: Float + Debug + scirs2_core::ndarray::ScalarOperand + NumAssign> {
    /// Model whose layers are laid out and drawn.
    model: Sequential<F>,
    /// Output settings (image format, output directory, ...).
    config: VisualizationConfig,
    /// Layout produced by the last call to `visualize_architecture`, if any.
    layout_cache: Option<NetworkLayout>,
}
/// Complete, serializable description of a laid-out network: where each layer
/// sits, which layers are connected, the overall extent, and the algorithm
/// that produced the positions.
#[derive(Debug, Clone, Serialize)]
pub struct NetworkLayout {
    /// One entry per model layer, in model order.
    pub layer_positions: Vec<LayerPosition>,
    /// Edges between layers; `from_layer` / `to_layer` index into `layer_positions`.
    pub connections: Vec<Connection>,
    /// Axis-aligned box enclosing every layer rectangle.
    pub bounds: BoundingBox,
    /// Layout algorithm used to compute `layer_positions`.
    pub algorithm: LayoutAlgorithm,
}
36
/// Placement and presentation of a single layer node in a layout.
#[derive(Debug, Clone, Serialize)]
pub struct LayerPosition {
    /// Display name, built as `<layer_type>_<model index>` (e.g. `Dense_0`).
    pub name: String,
    /// Layer type string as reported by the layer (e.g. `"Dense"`).
    pub layer_type: String,
    /// Center of the node's rectangle.
    pub position: Point2D,
    /// Width and height of the node's rectangle.
    pub size: Size2D,
    /// Shape, parameter and FLOP information shown for the layer.
    pub io_info: LayerIOInfo,
    /// Colors, border, opacity and optional icon used to draw the node.
    pub visual_props: LayerVisualProps,
}
53
54#[derive(Debug, Clone, Serialize)]
56pub struct Point2D {
57 pub x: f32,
59 pub y: f32,
61}
62
63#[derive(Debug, Clone, Serialize)]
65pub struct Size2D {
66 pub width: f32,
68 pub height: f32,
70}
71
/// Per-layer shape and cost figures shown alongside a node.
///
/// NOTE(review): the layout functions in this file currently fill these
/// fields with fixed placeholder values rather than values read from the
/// model — treat them as illustrative until wired to real layer metadata.
#[derive(Debug, Clone, Serialize)]
pub struct LayerIOInfo {
    /// Shape of the layer's input tensor.
    pub inputshape: Vec<usize>,
    /// Shape of the layer's output tensor.
    pub outputshape: Vec<usize>,
    /// Number of parameters; rendered as a "params" label under the node.
    pub parameter_count: usize,
    /// Estimated floating-point operation count.
    pub flops: u64,
}
84
/// Visual styling of a layer node.
#[derive(Debug, Clone, Serialize)]
pub struct LayerVisualProps {
    /// Fill color as a CSS color string (e.g. `"#4CAF50"`).
    pub fill_color: String,
    /// Border (stroke) color as a CSS color string.
    pub border_color: String,
    /// Border stroke width in layout units.
    pub border_width: f32,
    /// Node opacity, written to the SVG `opacity` attribute (layouts use 0.9).
    pub opacity: f32,
    /// Optional glyph drawn inside the node above the type label.
    pub icon: Option<String>,
}
99
/// A directed edge between two layers of a [`NetworkLayout`].
#[derive(Debug, Clone, Serialize)]
pub struct Connection {
    /// Index of the source layer in `NetworkLayout::layer_positions`.
    pub from_layer: usize,
    /// Index of the target layer in `NetworkLayout::layer_positions`.
    pub to_layer: usize,
    /// Semantic kind of the edge.
    pub connection_type: ConnectionType,
    /// How the edge is drawn.
    pub visual_props: ConnectionVisualProps,
    /// Description of the data flowing along the edge.
    pub data_flow: DataFlowInfo,
}
114
/// Semantic kind of an edge between two layers.
#[derive(Debug, Clone, PartialEq, Serialize)]
pub enum ConnectionType {
    /// Edge between consecutive layers (hierarchical, circular and grid layouts).
    Forward,
    /// Not emitted by the layouts in this file; presumably an edge that
    /// bypasses intermediate layers.
    Skip,
    /// Not emitted by the layouts in this file; presumably an attention link.
    Attention,
    /// Edge from the last layer back to the first, closing the circular layout.
    Recurrent,
    /// Edge between consecutive layers in the force-directed layout.
    Sequential,
    /// Not emitted by the layouts in this file; presumably a side link
    /// between neighbouring nodes.
    Lateral,
    /// Free-form kind identified by the contained label.
    Custom(String),
}
133
/// Visual styling of a connection line.
#[derive(Debug, Clone, Serialize)]
pub struct ConnectionVisualProps {
    /// Stroke color as a CSS color string (e.g. `"#666666"`).
    pub color: String,
    /// Stroke width in layout units.
    pub width: f32,
    /// Stroke pattern (solid, dashed, ...).
    pub style: LineStyle,
    /// Marker to draw at the end of the line.
    pub arrow: ArrowStyle,
    /// Line opacity, written to the SVG `opacity` attribute.
    pub opacity: f32,
}
148
149#[derive(Debug, Clone, PartialEq, Serialize)]
151pub enum LineStyle {
152 Solid,
154 Dashed,
156 Dotted,
158 DashDot,
160}
161
162#[derive(Debug, Clone, PartialEq, Serialize)]
164pub enum ArrowStyle {
165 None,
167 Simple,
169 Block,
171 Curved,
173}
174
/// Description of the data carried along a connection.
///
/// NOTE(review): the layouts in this file fill these fields with fixed
/// placeholder values rather than figures derived from the model.
#[derive(Debug, Clone, Serialize)]
pub struct DataFlowInfo {
    /// Shape of the tensor passed along the edge.
    pub tensorshape: Vec<usize>,
    /// Element type label (the layouts use `"f32"` or `"float32"`).
    pub data_type: String,
    /// Memory footprint of the transferred data — presumably bytes; confirm.
    pub memory_usage: usize,
    /// Batch size, if known.
    pub batch_size: Option<usize>,
    /// Optional throughput/latency figures.
    pub throughput: Option<ThroughputInfo>,
}

/// Throughput and latency figures for a connection.
#[derive(Debug, Clone, Serialize)]
pub struct ThroughputInfo {
    /// Samples processed per second.
    pub samples_per_second: f64,
    /// Bytes transferred per second.
    pub bytes_per_second: u64,
    /// Latency in milliseconds.
    pub latency_ms: f64,
}
200
/// Axis-aligned rectangle in layout coordinates; the SVG renderer pads it to
/// build the document's `viewBox`.
#[derive(Debug, Clone, Serialize)]
pub struct BoundingBox {
    /// Smallest x covered by any node.
    pub min_x: f32,
    /// Smallest y covered by any node.
    pub min_y: f32,
    /// Largest x covered by any node.
    pub max_x: f32,
    /// Largest y covered by any node.
    pub max_y: f32,
}
213
/// Strategy used to position layer nodes.
#[derive(Debug, Clone, PartialEq, Serialize)]
pub enum LayoutAlgorithm {
    /// Layers stacked in a single vertical column.
    Hierarchical,
    /// Layers positioned by a force simulation.
    ForceDirected,
    /// Layers placed evenly around a circle.
    Circular,
    /// Layers placed row-major in a near-square grid.
    Grid,
    /// Named custom algorithm; currently laid out with the hierarchical algorithm.
    Custom(String),
}
228
/// Minimal per-layer description extracted from the model before layout.
#[derive(Debug, Clone)]
pub struct LayerInfo {
    /// Display name, built as `<layer_type>_<layer_index>`.
    pub layer_name: String,
    /// Zero-based position of the layer in the model.
    pub layer_index: usize,
    /// Layer type string as reported by the layer.
    pub layer_type: String,
}
239
240impl<
242 F: Float
243 + Debug
244 + std::fmt::Display
245 + 'static
246 + scirs2_core::numeric::FromPrimitive
247 + scirs2_core::ndarray::ScalarOperand
248 + Send
249 + Sync
250 + NumAssign,
251 > NetworkVisualizer<F>
252{
253 pub fn new(model: Sequential<F>, config: VisualizationConfig) -> Self {
255 Self {
256 model,
257 config,
258 layout_cache: None,
259 }
260 }
261 pub fn visualize_architecture(&mut self) -> Result<PathBuf> {
263 let layout = self.compute_layout()?;
265 self.layout_cache = Some(layout.clone());
266 match self.config.image_format {
268 ImageFormat::SVG => self.generate_svg_visualization(&layout),
269 ImageFormat::HTML => self.generate_html_visualization(&layout),
270 ImageFormat::JSON => self.generate_json_visualization(&layout),
271 _ => self.generate_svg_visualization(&layout), }
273 }
274
275 fn compute_layout(&self) -> Result<NetworkLayout> {
277 let layer_info = self.analyze_model_structure()?;
279 let algorithm = self.select_layout_algorithm(&layer_info);
281 let (positions, connections) = match algorithm {
283 LayoutAlgorithm::Hierarchical => self.compute_hierarchical_layout(&layer_info)?,
284 LayoutAlgorithm::ForceDirected => self.compute_force_directed_layout(&layer_info)?,
285 LayoutAlgorithm::Circular => self.compute_circular_layout(&layer_info)?,
286 LayoutAlgorithm::Grid => self.compute_grid_layout(&layer_info)?,
287 LayoutAlgorithm::Custom(_) => self.compute_hierarchical_layout(&layer_info)?, };
289 let bounds = self.compute_bounds(&positions);
291 Ok(NetworkLayout {
292 layer_positions: positions,
293 connections,
294 bounds,
295 algorithm,
296 })
297 }
298
299 fn analyze_model_structure(&self) -> Result<Vec<LayerInfo>> {
300 let mut layer_info = Vec::new();
301 let layers = self.model.layers();
303 for (index, layer) in layers.iter().enumerate() {
304 let layer_type = layer.layer_type().to_string();
305 let layer_name = format!("{layer_type}_{index}");
306 layer_info.push(LayerInfo {
307 layer_name,
308 layer_index: index,
309 layer_type,
310 });
311 }
312
313 if layer_info.is_empty() {
315 return Err(NeuralError::InvalidArgument(
316 "Model has no layers".to_string(),
317 ));
318 }
319
320 Ok(layer_info)
321 }
322
323 fn select_layout_algorithm(&self, _layer_info: &[LayerInfo]) -> LayoutAlgorithm {
324 LayoutAlgorithm::Hierarchical
328 }
329
330 fn compute_hierarchical_layout(
331 &self,
332 layer_info: &[LayerInfo],
333 ) -> Result<(Vec<LayerPosition>, Vec<Connection>)> {
334 if layer_info.is_empty() {
335 return Ok((Vec::new(), Vec::new()));
336 }
337
338 let mut positions = Vec::new();
339 let mut connections = Vec::new();
340 let layer_width = 120.0;
342 let layer_height = 60.0;
343 let vertical_spacing = 100.0;
344 let horizontal_spacing = 150.0;
345 let total_width = (layer_info.len() as f32 - 1.0) * horizontal_spacing + layer_width;
347 let start_x = -total_width / 2.0 + layer_width / 2.0;
348 let start_y = -(layer_info.len() as f32 - 1.0) * vertical_spacing / 2.0;
349 for (i, layer) in layer_info.iter().enumerate() {
351 let x = start_x;
352 let y = start_y + i as f32 * vertical_spacing;
353 let (fill_color, border_color, icon) = match layer.layer_type.as_str() {
355 "Dense" => (
356 "#4CAF50".to_string(),
357 "#2E7D32".to_string(),
358 Some("◯".to_string()),
359 ),
360 "Conv2D" => (
361 "#2196F3".to_string(),
362 "#1565C0".to_string(),
363 Some("⬜".to_string()),
364 ),
365 "Conv1D" => (
366 "#03A9F4".to_string(),
367 "#0277BD".to_string(),
368 Some("▬".to_string()),
369 ),
370 "MaxPool2D" | "AvgPool2D" => (
371 "#FF9800".to_string(),
372 "#E65100".to_string(),
373 Some("▣".to_string()),
374 ),
375 "Dropout" => (
376 "#9C27B0".to_string(),
377 "#6A1B9A".to_string(),
378 Some("×".to_string()),
379 ),
380 "BatchNorm" => (
381 "#607D8B".to_string(),
382 "#37474F".to_string(),
383 Some("∼".to_string()),
384 ),
385 "Activation" => (
386 "#FFC107".to_string(),
387 "#F57C00".to_string(),
388 Some("∘".to_string()),
389 ),
390 "LSTM" => (
391 "#E91E63".to_string(),
392 "#AD1457".to_string(),
393 Some("⟲".to_string()),
394 ),
395 "GRU" => (
396 "#F44336".to_string(),
397 "#C62828".to_string(),
398 Some("⟳".to_string()),
399 ),
400 "Attention" => (
401 "#673AB7".to_string(),
402 "#4527A0".to_string(),
403 Some("◉".to_string()),
404 ),
405 _ => (
406 "#9E9E9E".to_string(),
407 "#424242".to_string(),
408 Some("?".to_string()),
409 ),
410 };
411 let parameter_count = match layer.layer_type.as_str() {
413 "Dense" => 10000, "Conv2D" => 5000,
415 "Conv1D" => 3000,
416 _ => 0,
417 };
418
419 let flops = match layer.layer_type.as_str() {
421 "Dense" => 100000,
422 "Conv2D" => 500000,
423 "Conv1D" => 200000,
424 _ => 1000,
425 };
426
427 let position = LayerPosition {
428 name: layer.layer_name.clone(),
429 layer_type: layer.layer_type.clone(),
430 position: Point2D { x, y },
431 size: Size2D {
432 width: layer_width,
433 height: layer_height,
434 },
435 io_info: LayerIOInfo {
436 inputshape: vec![32, 32, 3], outputshape: vec![32, 32, 3], parameter_count,
439 flops,
440 },
441 visual_props: LayerVisualProps {
442 fill_color,
443 border_color,
444 border_width: 2.0,
445 opacity: 0.9,
446 icon,
447 },
448 };
449
450 positions.push(position);
451 }
452
453 for i in 0..(layer_info.len().saturating_sub(1)) {
455 let connection = Connection {
456 from_layer: i,
457 to_layer: i + 1,
458 connection_type: ConnectionType::Forward,
459 visual_props: ConnectionVisualProps {
460 color: "#666666".to_string(),
461 width: 2.0,
462 style: LineStyle::Solid,
463 arrow: ArrowStyle::Simple,
464 opacity: 0.8,
465 },
466 data_flow: DataFlowInfo {
467 tensorshape: vec![32, 32, 3], data_type: "f32".to_string(),
469 memory_usage: 4096, batch_size: Some(32), throughput: Some(ThroughputInfo {
472 samples_per_second: 1000.0,
473 bytes_per_second: 4096000,
474 latency_ms: 1.0,
475 }),
476 },
477 };
478
479 connections.push(connection);
480 }
481
482 Ok((positions, connections))
483 }
484
485 fn compute_force_directed_layout(
486 &self,
487 layer_info: &[LayerInfo],
488 ) -> Result<(Vec<LayerPosition>, Vec<Connection>)> {
489 if layer_info.is_empty() {
490 return Ok((Vec::new(), Vec::new()));
491 }
492
493 let mut positions = Vec::new();
494 let mut connections = Vec::new();
495 let area = 800.0 * 600.0; let k = (area / layer_info.len() as f32).sqrt(); let iterations = 100;
499 let cooling_factor = 0.95;
500 let mut temperature = 100.0;
501 let mut node_positions: Vec<Point2D> = (0..layer_info.len())
503 .map(|i| Point2D {
504 x: ((i % 4) as f32 - 1.5) * 100.0, y: ((i / 4) as f32 - 1.5) * 100.0,
506 })
507 .collect();
508 for _iteration in 0..iterations {
510 let mut forces: Vec<Point2D> = vec![Point2D { x: 0.0, y: 0.0 }; layer_info.len()];
511 for i in 0..layer_info.len() {
513 for j in (i + 1)..layer_info.len() {
514 let dx = node_positions[i].x - node_positions[j].x;
515 let dy = node_positions[i].y - node_positions[j].y;
516 let distance = (dx * dx + dy * dy).sqrt().max(1.0);
517 let repulsive_force = k * k / distance;
518 let fx = repulsive_force * dx / distance;
519 let fy = repulsive_force * dy / distance;
520 forces[i].x += fx;
521 forces[i].y += fy;
522 forces[j].x -= fx;
523 forces[j].y -= fy;
524 }
525 }
526 for i in 0..(layer_info.len() - 1) {
528 let dx = node_positions[i].x - node_positions[i + 1].x;
529 let dy = node_positions[i].y - node_positions[i + 1].y;
530 let distance = (dx * dx + dy * dy).sqrt().max(1.0);
531 let attractive_force = distance * distance / k;
532 let fx = attractive_force * dx / distance;
533 let fy = attractive_force * dy / distance;
534 forces[i].x -= fx;
535 forces[i].y -= fy;
536 forces[i + 1].x += fx;
537 forces[i + 1].y += fy;
538 }
539
540 for i in 0..layer_info.len() {
542 let force_magnitude =
543 (forces[i].x * forces[i].x + forces[i].y * forces[i].y).sqrt();
544 if force_magnitude > 0.0 {
545 let displacement = temperature.min(force_magnitude);
546 node_positions[i].x += forces[i].x / force_magnitude * displacement;
547 node_positions[i].y += forces[i].y / force_magnitude * displacement;
548 }
549 }
550
551 temperature *= cooling_factor;
552 }
553
554 for (i, layer) in layer_info.iter().enumerate() {
556 let position = LayerPosition {
557 name: layer.layer_name.clone(),
558 layer_type: layer.layer_type.clone(),
559 position: node_positions[i].clone(),
560 size: Size2D {
561 width: 120.0,
562 height: 60.0,
563 },
564 io_info: LayerIOInfo {
565 inputshape: vec![1, 32], outputshape: vec![1, 32],
567 parameter_count: 1024,
568 flops: 2048,
569 },
570 visual_props: LayerVisualProps {
571 fill_color: "#8BC34A".to_string(),
572 border_color: "#558B2F".to_string(),
573 border_width: 2.0,
574 opacity: 0.9,
575 icon: Some("▢".to_string()),
576 },
577 };
578 positions.push(position);
579 }
580 for i in 0..(layer_info.len().saturating_sub(1)) {
582 let connection = Connection {
583 from_layer: i,
584 to_layer: i + 1,
585 connection_type: ConnectionType::Sequential,
586 visual_props: ConnectionVisualProps {
587 color: "#666666".to_string(),
588 width: 2.0,
589 style: LineStyle::Solid,
590 arrow: ArrowStyle::Simple,
591 opacity: 0.7,
592 },
593 data_flow: DataFlowInfo {
594 tensorshape: vec![1, 32],
595 data_type: "float32".to_string(),
596 memory_usage: 128, batch_size: Some(1),
598 throughput: None,
599 },
600 };
601 connections.push(connection);
602 }
603
604 Ok((positions, connections))
605 }
606
607 fn compute_circular_layout(
608 &self,
609 layer_info: &[LayerInfo],
610 ) -> Result<(Vec<LayerPosition>, Vec<Connection>)> {
611 if layer_info.is_empty() {
612 return Ok((Vec::new(), Vec::new()));
613 }
614
615 let mut positions = Vec::new();
616 let mut connections = Vec::new();
617 let radius = if layer_info.len() == 1 {
619 50.0
620 } else {
621 let circumference = layer_info.len() as f32 * 150.0; circumference / (2.0 * std::f32::consts::PI)
624 };
625
626 let center_x = 0.0;
627 let center_y = 0.0;
628
629 for (i, layer) in layer_info.iter().enumerate() {
631 let angle = if layer_info.len() == 1 {
632 0.0
633 } else {
634 2.0 * std::f32::consts::PI * i as f32 / layer_info.len() as f32
635 };
636
637 let x = center_x + radius * angle.cos();
638 let y = center_y + radius * angle.sin();
639
640 let position = LayerPosition {
641 name: layer.layer_name.clone(),
642 layer_type: layer.layer_type.clone(),
643 position: Point2D { x, y },
644 size: Size2D {
645 width: 120.0,
646 height: 60.0,
647 },
648 io_info: LayerIOInfo {
649 inputshape: vec![1, 32],
650 outputshape: vec![1, 32],
651 parameter_count: 1024,
652 flops: 2048,
653 },
654 visual_props: LayerVisualProps {
655 fill_color: "#FF9800".to_string(),
656 border_color: "#E65100".to_string(),
657 border_width: 2.0,
658 opacity: 0.9,
659 icon: Some("⭕".to_string()),
660 },
661 };
662 positions.push(position);
663 }
664
665 for i in 0..(layer_info.len().saturating_sub(1)) {
667 let connection = Connection {
668 from_layer: i,
669 to_layer: i + 1,
670 connection_type: ConnectionType::Forward,
671 visual_props: ConnectionVisualProps {
672 color: "#666666".to_string(),
673 width: 2.0,
674 style: LineStyle::Solid,
675 arrow: ArrowStyle::Simple,
676 opacity: 0.7,
677 },
678 data_flow: DataFlowInfo {
679 tensorshape: vec![1, 32],
680 data_type: "float32".to_string(),
681 memory_usage: 128,
682 batch_size: Some(1),
683 throughput: None,
684 },
685 };
686 connections.push(connection);
687 }
688 if layer_info.len() > 2 {
690 let connection = Connection {
691 from_layer: layer_info.len() - 1,
692 to_layer: 0,
693 connection_type: ConnectionType::Recurrent,
694 visual_props: ConnectionVisualProps {
695 color: "#999999".to_string(),
696 width: 1.5,
697 style: LineStyle::Dashed,
698 arrow: ArrowStyle::Simple,
699 opacity: 0.5,
700 },
701 data_flow: DataFlowInfo {
702 tensorshape: vec![1, 32],
703 data_type: "float32".to_string(),
704 memory_usage: 128,
705 batch_size: Some(1),
706 throughput: None,
707 },
708 };
709 connections.push(connection);
710 }
711
712 Ok((positions, connections))
713 }
714
715 fn compute_grid_layout(
716 &self,
717 layer_info: &[LayerInfo],
718 ) -> Result<(Vec<LayerPosition>, Vec<Connection>)> {
719 if layer_info.is_empty() {
720 return Ok((Vec::new(), Vec::new()));
721 }
722
723 let mut positions = Vec::new();
724 let mut connections = Vec::new();
725 let cell_width = 180.0;
727 let cell_height = 120.0;
728 let margin = 20.0;
729 let total_layers = layer_info.len();
731 let grid_cols = (total_layers as f32).sqrt().ceil() as usize;
732 let grid_rows = (total_layers as f32 / grid_cols as f32).ceil() as usize;
733 let total_width = grid_cols as f32 * cell_width;
735 let total_height = grid_rows as f32 * cell_height;
736 let start_x = -total_width / 2.0 + cell_width / 2.0;
737 let start_y = -total_height / 2.0 + cell_height / 2.0;
738 for (i, layer) in layer_info.iter().enumerate() {
740 let col = i % grid_cols;
741 let row = i / grid_cols;
742 let x = start_x + col as f32 * cell_width;
743 let y = start_y + row as f32 * cell_height;
744
745 let position = LayerPosition {
746 name: layer.layer_name.clone(),
747 layer_type: layer.layer_type.clone(),
748 position: Point2D { x, y },
749 size: Size2D {
750 width: cell_width - margin,
751 height: cell_height - margin,
752 },
753 io_info: LayerIOInfo {
754 inputshape: vec![1, 32],
755 outputshape: vec![1, 32],
756 parameter_count: 1024,
757 flops: 2048,
758 },
759 visual_props: LayerVisualProps {
760 fill_color: "#2196F3".to_string(),
761 border_color: "#1565C0".to_string(),
762 border_width: 2.0,
763 opacity: 0.9,
764 icon: Some("⬜".to_string()),
765 },
766 };
767 positions.push(position);
768 }
769
770 for i in 0..(layer_info.len().saturating_sub(1)) {
772 let from_col = i % grid_cols;
773 let from_row = i / grid_cols;
774 let to_col = (i + 1) % grid_cols;
775 let to_row = (i + 1) / grid_cols;
776
777 let (color, style, width) = if from_row == to_row {
779 ("#4CAF50".to_string(), LineStyle::Solid, 2.5)
781 } else if from_col == to_col {
782 ("#2196F3".to_string(), LineStyle::Solid, 2.5)
784 } else {
785 ("#FF9800".to_string(), LineStyle::Dashed, 2.0)
787 };
788
789 let connection = Connection {
790 from_layer: i,
791 to_layer: i + 1,
792 connection_type: ConnectionType::Forward,
793 visual_props: ConnectionVisualProps {
794 color,
795 width,
796 style,
797 arrow: ArrowStyle::Simple,
798 opacity: 0.7,
799 },
800 data_flow: DataFlowInfo {
801 tensorshape: vec![1, 32],
802 data_type: "float32".to_string(),
803 memory_usage: 128,
804 batch_size: Some(1),
805 throughput: None,
806 },
807 };
808 connections.push(connection);
809 }
810
811 if grid_rows > 1 {
814 for row in 0..grid_rows {
815 for col in 0..(grid_cols - 1) {
816 let from_idx = row * grid_cols + col;
817 let to_idx = row * grid_cols + col + 1;
818 if from_idx < total_layers && to_idx < total_layers && from_idx + 1 != to_idx {
819 let connection = Connection {
820 from_layer: from_idx,
821 to_layer: to_idx,
822 connection_type: ConnectionType::Lateral,
823 data_flow: DataFlowInfo {
824 tensorshape: vec![1, 16],
825 data_type: "float32".to_string(),
826 memory_usage: 64, batch_size: Some(1),
828 throughput: None,
829 },
830 visual_props: ConnectionVisualProps {
831 color: "#9E9E9E".to_string(),
832 width: 1.0,
833 style: LineStyle::Dotted,
834 arrow: ArrowStyle::None,
835 opacity: 0.4,
836 },
837 };
838 connections.push(connection);
839 }
840 }
841 }
842 }
843
844 Ok((positions, connections))
845 }
846
847 fn compute_bounds(&self, positions: &[LayerPosition]) -> BoundingBox {
848 if positions.is_empty() {
849 return BoundingBox {
850 min_x: 0.0,
851 min_y: 0.0,
852 max_x: 100.0,
853 max_y: 100.0,
854 };
855 }
856
857 let mut min_x = f32::INFINITY;
858 let mut min_y = f32::INFINITY;
859 let mut max_x = f32::NEG_INFINITY;
860 let mut max_y = f32::NEG_INFINITY;
861 for pos in positions {
862 min_x = min_x.min(pos.position.x - pos.size.width / 2.0);
863 min_y = min_y.min(pos.position.y - pos.size.height / 2.0);
864 max_x = max_x.max(pos.position.x + pos.size.width / 2.0);
865 max_y = max_y.max(pos.position.y + pos.size.height / 2.0);
866 }
867
868 BoundingBox {
869 min_x,
870 min_y,
871 max_x,
872 max_y,
873 }
874 }
875
876 fn generate_svg_visualization(&self, layout: &NetworkLayout) -> Result<PathBuf> {
877 let output_path = self.config.output_dir.join("network_architecture.svg");
878 let svg_content = self.create_svg_content(layout)?;
880 fs::write(&output_path, svg_content)
882 .map_err(|e| NeuralError::IOError(format!("Failed to write SVG file: {e}")))?;
883
884 Ok(output_path)
885 }
886
887 fn generate_html_visualization(&self, layout: &NetworkLayout) -> Result<PathBuf> {
888 let output_path = self.config.output_dir.join("network_architecture.html");
889 let html_content = self.create_html_content(layout)?;
891
892 fs::write(&output_path, html_content)
893 .map_err(|e| NeuralError::IOError(format!("Failed to write HTML file: {e}")))?;
894
895 Ok(output_path)
896 }
897
898 fn generate_json_visualization(&self, layout: &NetworkLayout) -> Result<PathBuf> {
899 let output_path = self.config.output_dir.join("network_architecture.json");
900 let json_content = serde_json::to_string_pretty(&layout).map_err(|e| {
902 NeuralError::SerializationError(format!("Failed to serialize layout: {e}"))
903 })?;
904
905 fs::write(&output_path, json_content)
906 .map_err(|e| NeuralError::IOError(format!("Failed to write JSON file: {e}")))?;
907
908 Ok(output_path)
909 }
910
911 fn create_svg_content(&self, layout: &NetworkLayout) -> Result<String> {
912 let bounds = &layout.bounds;
913 let margin = 50.0;
914 let svg_width = (bounds.max_x - bounds.min_x + 2.0 * margin) as u32;
916 let svg_height = (bounds.max_y - bounds.min_y + 2.0 * margin) as u32;
917 let viewbox_x = bounds.min_x - margin;
919 let viewbox_y = bounds.min_y - margin;
920 let viewbox_width = bounds.max_x - bounds.min_x + 2.0 * margin;
921 let viewbox_height = bounds.max_y - bounds.min_y + 2.0 * margin;
922 let mut svg = format!(
923 "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n\
924<svg width=\"{}\" height=\"{}\" viewBox=\"{} {} {} {}\" xmlns=\"http://www.w3.org/2000/svg\">\n\
925 <title>Neural Network Architecture</title>\n\
926 <defs>\n\
927 <style>\n\
928 .layer-rect {{ stroke-width: 2; }}\n\
929 .connection {{ fill: none; marker-end: url(#arrowhead); }}\n\
930 .layer-text {{ font-family: Arial, sans-serif; font-size: 11px; text-anchor: middle; fill: white; font-weight: bold; }}\n\
931 .layer-info {{ font-family: Arial, sans-serif; font-size: 9px; text-anchor: middle; fill: #333; }}\n\
932 .layer-icon {{ font-family: Arial, sans-serif; font-size: 16px; text-anchor: middle; fill: white; font-weight: bold; }}\n\
933 </style>\n\
934 <marker id=\"arrowhead\" markerWidth=\"10\" markerHeight=\"7\" refX=\"10\" refY=\"3.5\" orient=\"auto\">\n\
935 <polygon points=\"0 0, 10 3.5, 0 7\" fill=\"#{}\"/>\n\
936 </marker>\n\
937 </defs>\n\
938 \n\
939 <!-- Background -->\n\
940 <rect x=\"{}\" y=\"{}\" width=\"{}\" height=\"{}\" fill=\"#{}\" stroke=\"#{}\"/>\n\
941 \n",
942 svg_width, svg_height, viewbox_x, viewbox_y, viewbox_width, viewbox_height,
943 "666666",
944 viewbox_x, viewbox_y, viewbox_width, viewbox_height, "f8f9fa", "dee2e6"
945 );
946 for connection in &layout.connections {
948 if connection.from_layer < layout.layer_positions.len()
949 && connection.to_layer < layout.layer_positions.len()
950 {
951 let from_pos = &layout.layer_positions[connection.from_layer];
952 let to_pos = &layout.layer_positions[connection.to_layer];
953 let x1 = from_pos.position.x;
955 let y1 = from_pos.position.y + from_pos.size.height / 2.0;
956 let x2 = to_pos.position.x;
957 let y2 = to_pos.position.y - to_pos.size.height / 2.0;
958 let stroke_width = connection.visual_props.width;
959 let stroke_color = &connection.visual_props.color;
960 let opacity = connection.visual_props.opacity;
961 svg.push_str(&format!(
962 r#" <line x1="{}" y1="{}" x2="{}" y2="{}" stroke="{}" stroke-width="{}" opacity="{}" class="connection"/>
963"#,
964 x1, y1, x2, y2, stroke_color, stroke_width, opacity
965 ));
966 }
967 }
968
969 for (i, layer_pos) in layout.layer_positions.iter().enumerate() {
971 let x = layer_pos.position.x - layer_pos.size.width / 2.0;
972 let y = layer_pos.position.y - layer_pos.size.height / 2.0;
973 let width = layer_pos.size.width;
974 let height = layer_pos.size.height;
975 let fill_color = &layer_pos.visual_props.fill_color;
976 let border_color = &layer_pos.visual_props.border_color;
977 let border_width = layer_pos.visual_props.border_width;
978 let opacity = layer_pos.visual_props.opacity;
979 svg.push_str(&format!(
981 r#" <rect x="{}" y="{}" width="{}" height="{}" fill="{}" stroke="{}" stroke-width="{}" opacity="{}" rx="5" class="layer-rect"/>
982"#,
983 x, y, width, height, fill_color, border_color, border_width, opacity
984 ));
985
986 if let Some(ref icon) = layer_pos.visual_props.icon {
988 svg.push_str(&format!(
989 r#" <text x="{}" y="{}" class="layer-icon">{}</text>
990"#,
991 layer_pos.position.x,
992 layer_pos.position.y - 5.0,
993 icon
994 ));
995 }
996
997 svg.push_str(&format!(
999 r#" <text x="{}" y="{}" class="layer-text">{}</text>
1000"#,
1001 layer_pos.position.x,
1002 layer_pos.position.y + 8.0,
1003 layer_pos.layer_type
1004 ));
1005 let param_text = if layer_pos.io_info.parameter_count > 0 {
1007 format!("{}K params", layer_pos.io_info.parameter_count / 1000)
1008 } else {
1009 "No params".to_string()
1010 };
1011
1012 svg.push_str(&format!(
1013 r#" <text x="{}" y="{}" class="layer-info">{}</text>
1014"#,
1015 layer_pos.position.x,
1016 y + height + 15.0,
1017 param_text
1018 ));
1019
1020 svg.push_str(&format!(
1022 r#" <text x="{}" y="{}" class="layer-info">Layer {}</text>
1023"#,
1024 layer_pos.position.x,
1025 y - 10.0,
1026 i
1027 ));
1028 }
1029 let legend_x = viewbox_x + 10.0;
1031 let legend_y = viewbox_y + viewbox_height - 100.0;
1032 svg.push_str(&format!(
1033 " <!-- Legend -->\n\
1034 <rect x=\"{}\" y=\"{}\" width=\"200\" height=\"80\" fill=\"white\" stroke=\"#{}\" stroke-width=\"1\" opacity=\"0.9\" rx=\"5\"/>\n\
1035 <text x=\"{}\" y=\"{}\" font-family=\"Arial\" font-size=\"12\" font-weight=\"bold\" fill=\"#333\">Legend</text>\n\
1036 <text x=\"{}\" y=\"{}\" font-family=\"Arial\" font-size=\"10\" fill=\"#666\">◯ Dense Layer</text>\n\
1037 <text x=\"{}\" y=\"{}\" font-family=\"Arial\" font-size=\"10\" fill=\"#666\">⬜ Conv2D Layer</text>\n\
1038 <text x=\"{}\" y=\"{}\" font-family=\"Arial\" font-size=\"10\" fill=\"#666\">× Dropout Layer</text>\n\
1039 <text x=\"{}\" y=\"{}\" font-family=\"Arial\" font-size=\"10\" fill=\"#666\">∼ BatchNorm Layer</text>\n",
1040 legend_x, legend_y, "ccc",
1041 legend_x + 10.0, legend_y + 15.0,
1042 legend_x + 10.0, legend_y + 30.0,
1043 legend_x + 10.0, legend_y + 45.0,
1044 legend_x + 10.0, legend_y + 60.0,
1045 legend_x + 10.0, legend_y + 75.0
1046 ));
1047
1048 svg.push_str("</svg>");
1049
1050 Ok(svg)
1051 }
1052
1053 fn create_html_content(&self, layout: &NetworkLayout) -> Result<String> {
1054 let svg_content = self.create_svg_content(layout)?;
1056 let html_content = format!(
1058 r#"<!DOCTYPE html>
1059<html lang="en">
1060<head>
1061 <meta charset="UTF-8">
1062 <meta name="viewport" content="width=device-width, initial-scale=1.0">
1063 <title>Interactive Neural Network Architecture</title>
1064 <style>
1065 body {{
1066 font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
1067 margin: 0;
1068 padding: 20px;
1069 background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
1070 color: #333;
1071 }}
1072
1073 .header {{
1074 text-align: center;
1075 margin-bottom: 30px;
1076 background: white;
1077 border-radius: 10px;
1078 box-shadow: 0 4px 6px rgba(0,0,0,0.1);
1079 .controls {{
1080 margin-bottom: 20px;
1081 .control-group {{
1082 display: inline-block;
1083 margin-right: 20px;
1084 vertical-align: top;
1085 .control-group label {{
1086 display: block;
1087 font-weight: bold;
1088 margin-bottom: 5px;
1089 color: #555;
1090 button {{
1091 padding: 10px 20px;
1092 margin: 5px;
1093 border: none;
1094 border-radius: 5px;
1095 cursor: pointer;
1096 font-size: 14px;
1097 transition: all 0.3s ease;
1098 background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
1099 color: white;
1100 button:hover {{
1101 transform: translateY(-2px);
1102 box-shadow: 0 5px 15px rgba(0,0,0,0.2);
1103 button:active {{
1104 transform: translateY(0);
1105 select {{
1106 padding: 8px 12px;
1107 border: 1px solid #ddd;
1108 #visualization {{
1109 overflow: hidden;
1110 position: relative;
1111 #network-svg {{
1112 width: 100%;
1113 height: 700px;
1114 transition: transform 0.3s ease;
1115 .layer-node {{
1116 .layer-node:hover {{
1117 stroke-width: 3;
1118 filter: drop-shadow(0 4px 8px rgba(0,0,0,0.3));
1119 .connection-line {{
1120 .connection-line:hover {{
1121 stroke-width: 4;
1122 opacity: 1;
1123 .info-panel {{
1124 position: absolute;
1125 top: 10px;
1126 right: 10px;
1127 background: rgba(255,255,255,0.95);
1128 padding: 15px;
1129 border-radius: 8px;
1130 max-width: 300px;
1131 display: none;
1132 .info-panel h3 {{
1133 margin: 0 0 10px 0;
1134 color: #444;
1135 .info-panel p {{
1136 margin: 5px 0;
1137 font-size: 13px;
1138 color: #666;
1139 .layout-controls {{
1140 margin-bottom: 10px;
1141 .hidden {{
1142 .highlight {{
1143 stroke: #ff6b6b !important;
1144 stroke-width: 4 !important;
1145 filter: drop-shadow(0 0 10px #ff6b6b);
1146 </style>
1147</head>
1148<body>
1149 <div class="header">
1150 <h1>Interactive Neural Network Architecture</h1>
1151 <p>Algorithm: {algorithm} | Layers: {layer_count} | Connections: {connection_count}</p>
1152 </div>
1153
1154 <div class="controls">
1155 <div class="control-group">
1156 <label>Zoom Controls:</label>
1157 <button onclick="zoomIn()">🔍+ Zoom In</button>
1158 <button onclick="zoomOut()">🔍- Zoom Out</button>
1159 <button onclick="resetView()">🎯 Reset View</button>
1160 </div>
1161 <label>Display Options:</label>
1162 <button onclick="toggleLabels()">🏷️ Toggle Labels</button>
1163 <button onclick="toggleConnections()">🔗 Toggle Connections</button>
1164 <button onclick="highlightPath()">⚡ Highlight Data Flow</button>
1165 <label>Layout Algorithm:</label>
1166 <select id="layoutSelect" onchange="changeLayout()">
1167 <option value="hierarchical">📊 Hierarchical</option>
1168 <option value="force-directed">🌟 Force-Directed</option>
1169 <option value="circular">⭕ Circular</option>
1170 <option value="grid">⬜ Grid</option>
1171 </select>
1172 <label>Animation:</label>
1173 <button onclick="animateDataFlow()">🎬 Animate Flow</button>
1174 <button onclick="showLayerDetails()">📋 Layer Details</button>
1175 <div id="visualization">
1176 <div id="network-svg-container">
1177 {svg_content}
1178 <div id="info-panel" class="info-panel">
1179 <h3 id="info-title">Layer Information</h3>
1180 <p><strong>Type:</strong> <span id="info-type">-</span></p>
1181 <p><strong>Input Shape:</strong> <span id="info-input">-</span></p>
1182 <p><strong>Output Shape:</strong> <span id="info-output">-</span></p>
1183 <p><strong>Parameters:</strong> <span id="info-params">-</span></p>
1184 <p><strong>FLOPs:</strong> <span id="info-flops">-</span></p>
1185 <script>
1186 // Global state
1187 let currentZoom = 1.0;
1188 let showLabels = true;
1189 let showConnections = true;
1190 let selectedLayer = null;
1191 let animationRunning = false;
1192 // SVG manipulation
1193 const svg = document.querySelector('#network-svg-container svg');
1194 const infoPanel = document.getElementById('info-panel');
1195 // Zoom functions
1196 function zoomIn() {{
1197 currentZoom = Math.min(currentZoom * 1.2, 3.0);
1198 updateZoom();
1199 function zoomOut() {{
1200 currentZoom = Math.max(currentZoom / 1.2, 0.3);
1201 function resetView() {{
1202 currentZoom = 1.0;
1203 clearHighlights();
1204 hideInfo();
1205 function updateZoom() {{
1206 if (svg) {{
1207 svg.style.transform = `scale(${{currentZoom}})`;
1208 }}
1209 // Label toggle
1210 function toggleLabels() {{
1211 showLabels = !showLabels;
1212 const labels = svg.querySelectorAll('text');
1213 labels.forEach(label => {{
1214 label.style.display = showLabels ? 'block' : 'none';
1215 }});
1216 // Connection toggle
1217 function toggleConnections() {{
1218 showConnections = !showConnections;
1219 const connections = svg.querySelectorAll('.connection-line, line[stroke]');
1220 connections.forEach(conn => {{
1221 conn.style.display = showConnections ? 'block' : 'none';
1222 // Highlight data flow path
1223 function highlightPath() {{
1224 const layers = svg.querySelectorAll('rect, circle, ellipse');
1225 const connections = svg.querySelectorAll('line[stroke], path[stroke]');
1226
1227 // Sequential highlighting with delay
1228 layers.forEach((layer, index) => {{
1229 setTimeout(() => {{
1230 layer.classList.add('highlight');
1231 setTimeout(() => layer.classList.remove('highlight'), 1000);
1232 }}, index * 200);
1233 connections.forEach((conn, index) => {{
1234 conn.classList.add('highlight');
1235 setTimeout(() => conn.classList.remove('highlight'), 1000);
1236 }}, index * 200 + 100);
1237 // Animate data flow
1238 function animateDataFlow() {{
1239 if (animationRunning) return;
1240 animationRunning = true;
1241 conn.style.strokeDasharray = '10,5';
1242 conn.style.strokeDashoffset = '0';
1243 conn.style.animation = 'flow 2s linear infinite';
1244 }}, index * 100);
1245 // Add CSS animation dynamically
1246 const style = document.createElement('style');
1247 style.textContent = `
1248 @keyframes flow {{
1249 to {{ stroke-dashoffset: -15; }}
1250 }}
1251 `;
1252 document.head.appendChild(style);
1253 setTimeout(() => {{
1254 connections.forEach(conn => {{
1255 conn.style.animation = '';
1256 conn.style.strokeDasharray = '';
1257 conn.style.strokeDashoffset = '';
1258 }});
1259 animationRunning = false;
1260 }}, 5000);
1261 // Layer details
1262 function showLayerDetails() {{
1263 layer.addEventListener('click', () => showLayerInfo(layer, index));
1264 layer.style.cursor = 'pointer';
1265 function showLayerInfo(layer, index) {{
1266 selectedLayer = layer;
1267 // Highlight selected layer
1268 layer.classList.add('highlight');
1269 // Show info panel with layer details
1270 document.getElementById('info-title').textContent = `Layer ${{index + 1}}`;
1271 document.getElementById('info-type').textContent = layer.getAttribute('data-type') || 'Unknown';
1272 document.getElementById('info-input').textContent = layer.getAttribute('data-input') || '[1, 32]';
1273 document.getElementById('info-output').textContent = layer.getAttribute('data-output') || '[1, 32]';
1274 document.getElementById('info-params').textContent = layer.getAttribute('data-params') || '1,024';
1275 document.getElementById('info-flops').textContent = layer.getAttribute('data-flops') || '2,048';
1276 infoPanel.style.display = 'block';
1277 function hideInfo() {{
1278 infoPanel.style.display = 'none';
1279 selectedLayer = null;
1280 function clearHighlights() {{
1281 const highlighted = svg.querySelectorAll('.highlight');
1282 highlighted.forEach(el => el.classList.remove('highlight'));
1283 // Layout change implementation
1284 function changeLayout() {{
1285 const select = document.getElementById('layoutSelect');
1286 const layout = select.value;
1287 console.log(`Switching to ${{layout}} layout`);
1288 // Apply different layout algorithms
1289 switch(layout) {{
1290 case 'hierarchical':
1291 applyHierarchicalLayout();
1292 break;
1293 case 'circular':
1294 applyCircularLayout();
1295 case 'force':
1296 applyForceDirectedLayout();
1297 case 'grid':
1298 applyGridLayout();
1299 default:
1300 applyDefaultLayout();
1301 function applyHierarchicalLayout() {{
1302 const width = svg.viewBox.baseVal.width || 800;
1303 const height = svg.viewBox.baseVal.height || 600;
1304 const margin = 50;
1305 const x = margin + (index % 4) * (width - 2 * margin) / 3;
1306 const y = margin + Math.floor(index / 4) * (height - 2 * margin) / 3;
1307 layer.setAttribute('x', x);
1308 layer.setAttribute('y', y);
1309 function applyCircularLayout() {{
1310 const centerX = (svg.viewBox.baseVal.width || 800) / 2;
1311 const centerY = (svg.viewBox.baseVal.height || 600) / 2;
1312 const radius = Math.min(centerX, centerY) - 100;
1313 const angle = (2 * Math.PI * index) / layers.length;
1314 const x = centerX + radius * Math.cos(angle);
1315 const y = centerY + radius * Math.sin(angle);
1316 function applyForceDirectedLayout() {{
1317 // Simple force-directed positioning
1318 const x = Math.random() * (width - 100) + 50;
1319 const y = Math.random() * (height - 100) + 50;
1320 function applyGridLayout() {{
1321 const cols = Math.ceil(Math.sqrt(layers.length));
1322 const rows = Math.ceil(layers.length / cols);
1323 const col = index % cols;
1324 const row = Math.floor(index / cols);
1325 const x = 50 + col * (width - 100) / cols;
1326 const y = 50 + row * (height - 100) / rows;
1327 function applyDefaultLayout() {{
1328 const x = 50 + (index * 100) % (width - 100);
1329 const y = 100 + Math.floor((index * 100) / (width - 100)) * 80;
1330 layer.setAttribute('x', x);
1331 layer.setAttribute('y', y);
1332 }});
1333 }}
1334
1335 function applyDefaultLayout() {{
1336 const width = svg.viewBox.baseVal.width || 800;
1337 const height = svg.viewBox.baseVal.height || 600;
1338 layers.forEach((layer, index) => {{
1339 const x = 50 + (index * 100) % (width - 100);
1340 const y = 100 + Math.floor((index * 100) / (width - 100)) * 80;
1341 layer.setAttribute('x', x);
1342 layer.setAttribute('y', y);
1343 }});
1344 }}
1345
1346 // Initialize interactive features
1347 document.addEventListener('DOMContentLoaded', function() {{
1348 // Add event listeners to existing SVG elements
1349 showLayerDetails();
1350 // Close info panel when clicking outside
1351 document.addEventListener('click', function(e) {{
1352 if (!infoPanel.contains(e.target) && !e.target.closest('rect, circle, ellipse')) {{
1353 hideInfo();
1354 clearHighlights();
1355 }}
1356 }});
1357
1358 // Keyboard shortcuts
1359 document.addEventListener('keydown', function(e) {{
1360 switch(e.key) {{
1361 case '+':
1362 case '=':
1363 zoomIn();
1364 break;
1365 case '-':
1366 zoomOut();
1367 break;
1368 case '0':
1369 resetView();
1370 break;
1371 case 'l':
1372 toggleLabels();
1373 break;
1374 case 'c':
1375 toggleConnections();
1376 break;
1377 case 'h':
1378 highlightPath();
1379 break;
1380 }}
1381 }});
1382 }});
1383 </script>
1384</body>
1385</html>"#,
1386 algorithm = format_args!("{:?}", layout.algorithm),
1387 layer_count = layout.layer_positions.len(),
1388 connection_count = layout.connections.len(),
1389 svg_content = svg_content
1390 );
1391
1392 Ok(html_content)
1393 }
1394
1395 pub fn get_cached_layout(&self) -> Option<&NetworkLayout> {
1397 self.layout_cache.as_ref()
1398 }
1399
1400 pub fn clear_cache(&mut self) {
1402 self.layout_cache = None;
1403 }
1404
1405 pub fn update_config(&mut self, config: VisualizationConfig) {
1407 self.config = config;
1408 self.clear_cache(); }
1410}
1411
#[cfg(test)]
mod tests {
    use super::*;
    use crate::layers::Dense;
    use scirs2_core::random::SeedableRng;

    /// Builds a visualizer around a single-layer model with the default config.
    fn make_visualizer() -> NetworkVisualizer<f32> {
        let mut rng = scirs2_core::random::rngs::StdRng::seed_from_u64(42);
        let mut model = Sequential::<f32>::new();
        model.add_layer(Dense::new(10, 5, Some("relu"), &mut rng).expect("Operation failed"));
        NetworkVisualizer::new(model, VisualizationConfig::default())
    }

    /// Asserts that every value equals itself and differs from every other value.
    ///
    /// Replaces tautological `assert_eq!(X, X)` checks. Those can never fail,
    /// whereas this one also catches two variants that compare as equal.
    fn assert_pairwise_distinct<T: PartialEq + std::fmt::Debug>(values: &[T]) {
        for (i, a) in values.iter().enumerate() {
            for (j, b) in values.iter().enumerate() {
                assert_eq!(
                    i == j,
                    a == b,
                    "unexpected equality result for {:?} vs {:?}",
                    a,
                    b
                );
            }
        }
    }

    #[test]
    fn test_network_visualizer_creation() {
        let visualizer = make_visualizer();

        assert!(visualizer.layout_cache.is_none());
        assert!(visualizer.get_cached_layout().is_none());
    }

    #[test]
    fn test_layout_algorithm_variants() {
        assert_pairwise_distinct(&[
            LayoutAlgorithm::Hierarchical,
            LayoutAlgorithm::ForceDirected,
            LayoutAlgorithm::Circular,
            LayoutAlgorithm::Grid,
        ]);
    }

    #[test]
    fn test_connection_types() {
        assert_pairwise_distinct(&[
            ConnectionType::Forward,
            ConnectionType::Skip,
            ConnectionType::Attention,
            ConnectionType::Recurrent,
            ConnectionType::Sequential,
            ConnectionType::Lateral,
            ConnectionType::Custom("test".to_string()),
            ConnectionType::Custom("other".to_string()),
        ]);

        let custom = ConnectionType::Custom("test".to_string());
        match custom {
            ConnectionType::Custom(name) => assert_eq!(name, "test"),
            other => panic!("Expected custom connection type, got {:?}", other),
        }
    }

    #[test]
    fn test_bounding_box_computation() {
        let visualizer = make_visualizer();

        // Explicit element type so the test does not rely on inference through
        // `compute_bounds`'s parameter type.
        let empty_positions: Vec<LayerPosition> = Vec::new();
        let bounds = visualizer.compute_bounds(&empty_positions);
        assert_eq!(bounds.min_x, 0.0);
        assert_eq!(bounds.min_y, 0.0);
        assert_eq!(bounds.max_x, 100.0);
        assert_eq!(bounds.max_y, 100.0);
    }

    #[test]
    fn test_cache_is_invalidated() {
        let mut visualizer = make_visualizer();
        let no_positions: Vec<LayerPosition> = Vec::new();
        let cached = NetworkLayout {
            layer_positions: Vec::new(),
            connections: Vec::new(),
            bounds: visualizer.compute_bounds(&no_positions),
            algorithm: LayoutAlgorithm::Grid,
        };

        // clear_cache drops a stored layout.
        visualizer.layout_cache = Some(cached.clone());
        assert!(visualizer.get_cached_layout().is_some());
        visualizer.clear_cache();
        assert!(visualizer.get_cached_layout().is_none());

        // update_config must also invalidate, since the layout depends on the config.
        visualizer.layout_cache = Some(cached);
        visualizer.update_config(VisualizationConfig::default());
        assert!(visualizer.get_cached_layout().is_none());
    }

    #[test]
    fn test_point_2d() {
        let point = Point2D { x: 10.0, y: 20.0 };

        assert_eq!(point.x, 10.0);
        assert_eq!(point.y, 20.0);
    }

    #[test]
    fn test_size_2d() {
        let size = Size2D {
            width: 100.0,
            height: 50.0,
        };

        assert_eq!(size.width, 100.0);
        assert_eq!(size.height, 50.0);
    }

    #[test]
    fn test_line_style_variants() {
        assert_pairwise_distinct(&[
            LineStyle::Solid,
            LineStyle::Dashed,
            LineStyle::Dotted,
            LineStyle::DashDot,
        ]);
    }

    #[test]
    fn test_arrow_style_variants() {
        assert_pairwise_distinct(&[
            ArrowStyle::None,
            ArrowStyle::Simple,
            ArrowStyle::Block,
            ArrowStyle::Curved,
        ]);
    }
}