scirs2_neural/utils/model_viz.rs

//! Model architecture visualization utilities
//!
//! This module provides utilities for visualizing the architecture
//! of neural network models.

use crate::error::{NeuralError, Result};
use crate::layers::{Layer, Sequential};
use crate::utils::colors::{colorize, stylize, Color, ColorOptions, Style};
use scirs2_core::ndarray::ScalarOperand;
use scirs2_core::numeric::Float;
use std::fmt::Debug;

/// Represents a node in the model architecture graph
#[derive(Debug, Clone)]
struct ModelNode {
    /// Layer name or description
    name: String,
    /// Input shape
    inputshape: Option<Vec<usize>>,
    /// Output shape
    outputshape: Option<Vec<usize>>,
    /// Number of parameters
    parameters: Option<usize>,
    /// Layer type
    layer_type: String,
    /// Additional properties
    properties: Vec<(String, String)>,
}

/// Options for model architecture visualization
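///
/// A minimal construction sketch (illustrative only; the field values shown
/// here are arbitrary, not recommended settings):
///
/// ```ignore
/// let options = ModelVizOptions {
///     width: 100,
///     show_params: true,
///     ..ModelVizOptions::default()
/// };
/// ```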
pub struct ModelVizOptions {
    /// Width of the visualization
    pub width: usize,
    /// Show parameter counts
    pub show_params: bool,
    /// Show layer shapes
    pub showshapes: bool,
    /// Show layer properties
    pub show_properties: bool,
    /// Color options
    pub color_options: ColorOptions,
}

impl Default for ModelVizOptions {
    fn default() -> Self {
        Self {
            width: 80,
            show_params: true,
            showshapes: true,
            show_properties: true,
            color_options: ColorOptions::default(),
        }
    }
}

/// Create an ASCII text representation of a sequential model architecture
///
/// # Arguments
///
/// * `model` - The sequential model to visualize
/// * `inputshape` - Optional input shape to propagate through the model
/// * `title` - Optional title for the visualization
/// * `options` - Visualization options
///
/// # Returns
///
/// * `Result<String>` - ASCII representation of the model architecture
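///
/// # Example
///
/// A minimal usage sketch (not compiled as a doctest; the model construction
/// and the shape values are assumptions, not taken from this module):
///
/// ```ignore
/// let model: Sequential<f32> = build_model(); // hypothetical helper
/// let summary = sequential_model_summary(
///     &model,
///     Some(vec![32, 784]), // assumed input shape: batch of 32, 784 features
///     Some("My Model"),
///     None, // use default ModelVizOptions
/// )?;
/// println!("{summary}");
/// ```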
#[allow(dead_code)]
pub fn sequential_model_summary<
    F: Float + Debug + ScalarOperand + scirs2_core::numeric::FromPrimitive + std::fmt::Display,
>(
    model: &Sequential<F>,
    inputshape: Option<Vec<usize>>,
    title: Option<&str>,
    options: Option<ModelVizOptions>,
) -> Result<String> {
    let options = options.unwrap_or_default();
    // Note: `options.width` is not used here; column widths are derived from the content below
    let colors = &options.color_options;
    let mut result = String::new();
    // Add title
    if let Some(titletext) = title {
        if colors.enabled {
            result.push_str(&stylize(titletext, Style::Bold));
        } else {
            result.push_str(titletext);
        }
        result.push_str("\n\n");
    }
    // Extract layer information
    let layer_infos = model.layer_info();
    if layer_infos.is_empty() {
        return Err(NeuralError::ValidationError(
            "Model has no layers".to_string(),
        ));
    }

    // Create nodes for each layer
    let mut nodes = Vec::new();
    // Add input node if shape is provided
    if let Some(shape) = inputshape.clone() {
        nodes.push(ModelNode {
            name: "Input".to_string(),
            inputshape: None,
            outputshape: Some(shape),
            parameters: Some(0),
            layer_type: "Input".to_string(),
            properties: Vec::new(),
        });
    }

    // Add actual layer nodes
    for layer_info in &layer_infos {
        let layer_name = if layer_info.name.starts_with("Layer_") {
            let index = layer_info.index + 1;
            format!("Layer {index}")
        } else {
            layer_info.name.clone()
        };
        // Create properties from layer info
        let mut properties = Vec::new();
        if let Some(ref inputshape) = layer_info.inputshape {
            properties.push(("Input Shape".to_string(), format!("{inputshape:?}")));
        }
        if let Some(ref outputshape) = layer_info.outputshape {
            properties.push(("Output Shape".to_string(), format!("{outputshape:?}")));
        }

        let node = ModelNode {
            name: layer_name,
            inputshape: layer_info.inputshape.clone(),
            outputshape: layer_info.outputshape.clone(),
            parameters: Some(layer_info.parameter_count),
            layer_type: layer_info.layer_type.clone(),
            properties,
        };
        nodes.push(node);
    }

    // Try to propagate shapes if input shape is provided
    if let Some(inputshape) = inputshape {
        // For now, simplified approach since we can't easily run the forward pass here
        // In a full implementation, this would use actual layer logic
        let mut currentshape = inputshape;
        for (i, node) in nodes.iter_mut().enumerate() {
            if i > 0 {
                // Skip input node
                node.inputshape = Some(currentshape.clone());
                // Very simplified shape propagation (would need more detailed layer info)
                if node.layer_type == "Dense" {
                    if let Some(output_size) = extract_output_size(node) {
                        // For Dense layers, output shape is (batch_size, output_size)
                        if !currentshape.is_empty() {
                            let mut outputshape = currentshape.clone();
                            if outputshape.len() > 1 {
                                let last_idx = outputshape.len() - 1;
                                outputshape[last_idx] = output_size;
                            } else {
                                outputshape = vec![output_size];
                            }
                            currentshape = outputshape.clone();
                            node.outputshape = Some(outputshape);
                        }
                    }
                } else {
                    // For other layer types, assume shape is preserved
                    node.outputshape = Some(currentshape.clone());
                }
            }
        }
    }

    // Calculate total parameters
    let total_params: usize = nodes.iter().filter_map(|node| node.parameters).sum();
    // Determine column widths
    let name_width = nodes
        .iter()
        .map(|node| node.name.len())
        .max()
        .unwrap_or(10)
        .max(10);
    let type_width = nodes
        .iter()
        .map(|node| node.layer_type.len())
        .max()
        .unwrap_or(8)
        .max(8);
    let shape_width = if options.showshapes {
        nodes
            .iter()
            .map(|node| {
                let input_str = node.inputshape.as_ref().map(|s| format!("{s:?}"));
                let output_str = node.outputshape.as_ref().map(|s| format!("{s:?}"));
                let input_len = input_str.as_ref().map(|s| s.len()).unwrap_or(0);
                let output_len = output_str.as_ref().map(|s| s.len()).unwrap_or(0);
                input_len.max(output_len)
            })
            .max()
            .unwrap_or(15)
            .max(15)
    } else {
        0
    };
    let params_width = if options.show_params {
        14 // Room for formatted parameter counts
    } else {
        0
    };

    // Add header
    let mut header = format!(
        "{:<width$} | {:<type_width$}",
        if options.color_options.enabled {
            stylize("Layer", Style::Bold).to_string()
        } else {
            "Layer".to_string()
        },
        if options.color_options.enabled {
            stylize("Type", Style::Bold).to_string()
        } else {
            "Type".to_string()
        },
        width = name_width,
        type_width = type_width
    );
    if options.showshapes {
        header.push_str(&format!(
            " | {:<shape_width$}",
            if options.color_options.enabled {
                stylize("Output Shape", Style::Bold).to_string()
            } else {
                "Output Shape".to_string()
            },
            shape_width = shape_width
        ));
    }
    if options.show_params {
        header.push_str(&format!(
            " | {:<params_width$}",
            if options.color_options.enabled {
                stylize("Params", Style::Bold).to_string()
            } else {
                "Params".to_string()
            },
            params_width = params_width
        ));
    }

    result.push_str(&header);
    result.push('\n');
    // Add separator
    let total_width = name_width
        + type_width
        + (if options.showshapes {
            shape_width + 3
        } else {
            0
        })
        + (if options.show_params {
            params_width + 3
        } else {
            0
        })
        + 1;
    result.push_str(&"-".repeat(total_width));
    result.push('\n');
    // Add layers
    for node in &nodes {
        // Layer name with color
        let mut line = if options.color_options.enabled {
            let styled_name = match node.layer_type.as_str() {
                "Input" => colorize(&node.name, Color::BrightCyan),
                "Dense" => colorize(&node.name, Color::BrightGreen),
                "Conv2D" => colorize(&node.name, Color::BrightMagenta),
                "RNN" | "LSTM" | "GRU" => colorize(&node.name, Color::BrightBlue),
                "BatchNorm" | "Dropout" => colorize(&node.name, Color::Yellow),
                _ => colorize(&node.name, Color::BrightWhite),
            };
            format!("{:<width$} | ", styled_name, width = name_width + 9) // Add space for ANSI codes
        } else {
            format!("{:<width$} | ", node.name, width = name_width)
        };
        // Layer type
        line.push_str(&format!(
            "{:<type_width$}",
            node.layer_type,
            type_width = type_width
        ));

        // Output shape
        if options.showshapes {
            let shape_str = if let Some(shape) = &node.outputshape {
                format!("{shape:?}")
            } else {
                "?".to_string()
            };
            line.push_str(&format!(" | {shape_str:<shape_width$}"));
        }

        // Parameters
        if options.show_params {
            if let Some(params) = node.parameters {
                let params_str = if params >= 1_000_000 {
                    let param_mb = params as f64 / 1_000_000.0;
                    format!("{param_mb:.2}M")
                } else if params >= 1_000 {
                    let param_kb = params as f64 / 1_000.0;
                    format!("{param_kb:.2}K")
                } else {
                    format!("{params}")
                };
                line.push_str(&format!(" | {params_str:<params_width$}"));
            } else {
                line.push_str(&format!(" | {question:<params_width$}", question = "?"));
            }
        }

        result.push_str(&line);
        result.push('\n');
        // Add properties if enabled
        if options.show_properties && !node.properties.is_empty() {
            for (key, value) in &node.properties {
                let prop_line = if options.color_options.enabled {
                    let styled_key = stylize(format!("  - {key}"), Style::Dim);
                    format!("{styled_key}: {value}")
                } else {
                    format!("  - {key}: {value}")
                };
                result.push_str(&prop_line);
                result.push('\n');
            }
        }
    }

    // Add summary information
    let trainable_params = total_params; // For now, assume all are trainable
    let formatted_total = format_params(total_params);
    let summary = format!("Total parameters: {formatted_total}");
    if options.color_options.enabled {
        result.push_str(&stylize(&summary, Style::Bold));
    } else {
        result.push_str(&summary);
    }
    result.push('\n');

    // Trainable parameters
    let formatted_trainable = format_params(trainable_params);
    let trainable_summary = format!("Trainable parameters: {formatted_trainable}");
    if options.color_options.enabled {
        result.push_str(&stylize(&trainable_summary, Style::Bold));
    } else {
        result.push_str(&trainable_summary);
    }
    result.push('\n');
    // Non-trainable parameters
    let non_trainable_params = total_params - trainable_params;
    let non_trainable_summary = format!(
        "Non-trainable parameters: {}",
        format_params(non_trainable_params)
    );
    if options.color_options.enabled {
        result.push_str(&stylize(&non_trainable_summary, Style::Bold));
    } else {
        result.push_str(&non_trainable_summary);
    }
    result.push('\n');

    Ok(result)
}

/// Creates an ASCII representation of the data flow through a sequential model
///
/// This visualization shows how data flows through the network layers,
/// including transformations in shape and any connections between layers.
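///
/// # Example
///
/// A minimal usage sketch (not compiled as a doctest; the model construction
/// and the input shape are assumptions, not taken from this module):
///
/// ```ignore
/// let model: Sequential<f32> = build_model(); // hypothetical helper
/// let diagram = sequential_model_dataflow(&model, vec![32, 784], None)?;
/// println!("{diagram}");
/// ```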
#[allow(dead_code)]
pub fn sequential_model_dataflow<
    F: Float + Debug + ScalarOperand + scirs2_core::numeric::FromPrimitive + std::fmt::Display,
>(
    model: &Sequential<F>,
    inputshape: Vec<usize>,
    options: Option<ModelVizOptions>,
) -> Result<String> {
    let options = options.unwrap_or_default();
    let width = options.width;
    // Create nodes for visualization (input + layers)
    let layer_infos = model.layer_info();
    let mut nodes: Vec<ModelNode> = Vec::with_capacity(layer_infos.len() + 1);
    // Add input node
    nodes.push(ModelNode {
        name: "Input".to_string(),
        inputshape: None,
        outputshape: Some(inputshape.clone()),
        parameters: Some(0),
        layer_type: "Input".to_string(),
        properties: Vec::new(),
    });
    // Add layer nodes with simplified shape propagation
    let mut currentshape = inputshape.clone();

    for (i, layer_info) in layer_infos.iter().enumerate() {
        let layer_name = if layer_info.name.starts_with("Layer_") {
            let index = i + 1;
            format!("Layer_{index}")
        } else {
            layer_info.name.clone()
        };
        let layer_type = layer_info.layer_type.clone();
        let mut properties: Vec<(String, String)> = Vec::new();
        if layer_info.parameter_count > 0 {
            properties.push((
                "Parameters".to_string(),
                layer_info.parameter_count.to_string(),
            ));
        }
        let inputshape = currentshape.clone();
        // Very simplified shape inference
        let outputshape = match layer_type.as_str() {
            "Dense" => {
                if let Some(output_size) = properties
                    .iter()
                    .find(|(key, _)| key == "output_dim")
                    .map(|(_, value)| value.parse::<usize>().unwrap_or(0))
                {
                    if !currentshape.is_empty() {
                        let mut newshape = currentshape.clone();
                        let last_idx = newshape.len() - 1;
                        newshape[last_idx] = output_size;
                        newshape
                    } else {
                        vec![output_size]
                    }
                } else {
                    currentshape.clone()
                }
            }
            // Conv2D and other layer types: very simplified; in reality we'd need
            // filter counts, kernel sizes, strides, padding, etc.
            "Conv2D" => currentshape.clone(),
            _ => currentshape.clone(),
        };

        currentshape = outputshape.clone();

        let node = ModelNode {
            name: layer_name,
            inputshape: Some(inputshape),
            outputshape: Some(outputshape),
            parameters: Some(layer_info.parameter_count),
            layer_type,
            properties,
        };
        nodes.push(node);
    }
    // Draw the data flow diagram
    //
    // Format:
    //    ┌──────────────────┐
    //    │       Input      │
    //    │ [batch, 28, 28, 1] │
    //    └──────────────────┘
    //             │
    //             ▼
    //    ┌──────────────────┐
    //    │      Conv2D      │
    //    │  [b, 26, 26, 32] │
    //    └──────────────────┘
    let mut result = String::new();
    let box_width = 20.min(width / 2);

    for (i, node) in nodes.iter().enumerate() {
        // Draw box top
        result.push_str(&" ".repeat((width - box_width) / 2));
        result.push('┌');
        result.push_str(&"─".repeat(box_width - 2));
        result.push('┐');
        result.push('\n');

        // Draw layer name
        let name = if node.layer_type == "Input" {
            node.layer_type.clone()
        } else {
            format!("{} ({})", node.layer_type, node.name)
        };
        let padded_name = format!("{name:^width$}", width = box_width - 2);
        result.push_str(&" ".repeat((width - box_width) / 2));

        let styled_name = if options.color_options.enabled {
            match node.layer_type.as_str() {
                "Input" => colorize(&padded_name, Color::BrightCyan),
                "Dense" => colorize(&padded_name, Color::BrightGreen),
                "Conv2D" => colorize(&padded_name, Color::BrightMagenta),
                "RNN" | "LSTM" | "GRU" => colorize(&padded_name, Color::BrightBlue),
                "BatchNorm" | "Dropout" => colorize(&padded_name, Color::Yellow),
                _ => padded_name.to_string(),
            }
        } else {
            padded_name
        };

        result.push('│');
        result.push_str(&styled_name);
        result.push('│');
        result.push('\n');

        // Draw shape info
        if let Some(shape) = &node.outputshape {
            let shape_str = format!("{shape:?}");
            let paddedshape = format!("{shape_str:^width$}", width = box_width - 2);
            result.push_str(&" ".repeat((width - box_width) / 2));
            result.push('│');
            if options.color_options.enabled {
                result.push_str(&stylize(&paddedshape, Style::Dim));
            } else {
                result.push_str(&paddedshape);
            }
            result.push('│');
            result.push('\n');
        }
        // Draw box bottom
        result.push_str(&" ".repeat((width - box_width) / 2));
        result.push('└');
        result.push_str(&"─".repeat(box_width - 2));
        result.push('┘');
        result.push('\n');

        // Draw connector to next layer if not the last node
        if i < nodes.len() - 1 {
            result.push_str(&" ".repeat(width / 2));
            result.push('│');
            result.push('\n');
            result.push_str(&" ".repeat(width / 2));
            result.push('▼');
            result.push('\n');
        }
    }

    // Add parameter summary
    let total_params: usize = nodes.iter().filter_map(|node| node.parameters).sum();
    let formatted_total = format_params(total_params);
    let summary = format!("Total parameters: {formatted_total}");
    if options.color_options.enabled {
        result.push_str(&stylize(&summary, Style::Bold));
    } else {
        result.push_str(&summary);
    }
    result.push('\n');

    Ok(result)
}

// Helper function to extract output size from a layer's properties
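// Illustrative behaviour (the "output_dim" key matches the lookup below): a Dense node
// whose properties contain ("output_dim", "128") yields Some(128); anything else yields None.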
#[allow(dead_code)]
fn extract_output_size(node: &ModelNode) -> Option<usize> {
    if node.layer_type == "Dense" {
        for (key, value) in &node.properties {
            if key == "output_dim" {
                return value.parse::<usize>().ok();
            }
        }
    }
    None
}

// Helper function to extract useful properties from a layer
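// Illustrative behaviour (the description format is an assumption about `layer_description()`):
// a description such as "type: Dense, input_dim: 64, output_dim: 10" would yield
// [("input_dim", "64"), ("output_dim", "10")], with the "type" entry skipped.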
#[allow(dead_code)]
fn extract_layer_properties<F: Float + Debug + ScalarOperand>(
    layer: &(dyn Layer<F> + Send + Sync),
) -> Vec<(String, String)> {
    let mut properties = Vec::new();
    let description = layer.layer_description();
    // Very simple parsing of layer description
    // In a real implementation, we'd want direct API access to layer properties
    let parts: Vec<&str> = description.split(',').collect();
    for part in parts {
        let kv: Vec<&str> = part.split(':').collect();
        if kv.len() == 2 {
            let key = kv[0].trim().to_string();
            let value = kv[1].trim().to_string();
            if key != "type" && !key.is_empty() && !value.is_empty() {
                properties.push((key, value));
            }
        }
    }
    properties
}

// Helper function to format parameter counts
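// Illustrative behaviour: format_params(1_234_567) -> "1.23M (1234567 parameters)",
// format_params(2_500) -> "2.50K (2500 parameters)", format_params(42) -> "42 parameters".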
#[allow(dead_code)]
fn format_params(params: usize) -> String {
    if params >= 1_000_000 {
        let param_m = params as f64 / 1_000_000.0;
        format!("{param_m:.2}M ({params} parameters)")
    } else if params >= 1_000 {
        let param_kb = params as f64 / 1_000.0;
        format!("{param_kb:.2}K ({params} parameters)")
    } else {
        format!("{params} parameters")
    }
}