1use crate::error::{NeuralError, Result};
7use crate::layers::Layer;
8use crate::layers::Sequential;
9use crate::utils::colors::{colorize, stylize, Color, ColorOptions, Style};
10use scirs2_core::ndarray::ScalarOperand;
11use scirs2_core::numeric::Float;
12use std::fmt::Debug;
/// A single entry in a rendered model visualization: either the synthetic
/// `Input` node or one layer of the model.
#[derive(Debug, Clone)]
struct ModelNode {
    /// Display name: `"Input"`, a generated `"Layer N"` / `"Layer_N"` label,
    /// or the layer's own name.
    name: String,
    /// Shape entering this node, if known (`None` for the `Input` node).
    inputshape: Option<Vec<usize>>,
    /// Shape leaving this node, if known; the summary table shows `?` when `None`.
    outputshape: Option<Vec<usize>>,
    /// Parameter count, if known; the summary table shows `?` when `None`.
    parameters: Option<usize>,
    /// Layer kind (e.g. `"Input"`, `"Dense"`); selects the display color and,
    /// for `"Dense"`, enables output-size inference.
    layer_type: String,
    /// Extra `(key, value)` pairs listed beneath the node in the summary table;
    /// an `"output_dim"` entry is read by `extract_output_size`.
    properties: Vec<(String, String)>,
}
/// Options controlling how `sequential_model_summary` and
/// `sequential_model_dataflow` render a model.
pub struct ModelVizOptions {
    /// Canvas width in columns; `sequential_model_dataflow` centres its boxes
    /// and arrows within this width. Defaults to 80.
    pub width: usize,
    /// Whether the summary table includes a "Params" column. Defaults to `true`.
    pub show_params: bool,
    /// Whether the summary table includes an "Output Shape" column. Defaults to `true`.
    pub showshapes: bool,
    /// Whether each summary row is followed by the layer's property list.
    /// Defaults to `true`.
    pub show_properties: bool,
    /// Color/styling settings; when `color_options.enabled` is set, output is
    /// decorated via `colorize` / `stylize`.
    pub color_options: ColorOptions,
}
42
43impl Default for ModelVizOptions {
44 fn default() -> Self {
45 Self {
46 width: 80,
47 show_params: true,
48 showshapes: true,
49 show_properties: true,
50 color_options: ColorOptions::default(),
51 }
52 }
53}
54
55#[allow(dead_code)]
65pub fn sequential_model_summary<
66 F: Float + Debug + ScalarOperand + scirs2_core::numeric::FromPrimitive + std::fmt::Display,
67>(
68 model: &Sequential<F>,
69 inputshape: Option<Vec<usize>>,
70 title: Option<&str>,
71 options: Option<ModelVizOptions>,
72) -> Result<String> {
73 let options = options.unwrap_or_default();
74 let colors = &options.color_options;
76 let mut result = String::new();
77 if let Some(titletext) = title {
79 if colors.enabled {
80 result.push_str(&stylize(titletext, Style::Bold));
81 } else {
82 result.push_str(titletext);
83 }
84 result.push_str("\n\n");
85 }
86 let layer_infos = model.layer_info();
88 if layer_infos.is_empty() {
89 return Err(NeuralError::ValidationError(
90 "Model has no layers".to_string(),
91 ));
92 }
93
94 let mut nodes = Vec::new();
96 if let Some(shape) = inputshape.clone() {
98 nodes.push(ModelNode {
99 name: "Input".to_string(),
100 inputshape: None,
101 outputshape: Some(shape),
102 parameters: Some(0),
103 layer_type: "Input".to_string(),
104 properties: Vec::new(),
105 });
106 }
107
108 for layer_info in &layer_infos {
110 let layer_name = if layer_info.name.starts_with("Layer_") {
111 let index = layer_info.index + 1;
112 format!("Layer {index}")
113 } else {
114 layer_info.name.clone()
115 };
116 let mut properties = Vec::new();
118 if let Some(ref inputshape) = layer_info.inputshape {
119 properties.push(("Input Shape".to_string(), format!("{inputshape:?}")));
120 }
121 if let Some(ref outputshape) = layer_info.outputshape {
122 properties.push(("Output Shape".to_string(), format!("{outputshape:?}")));
123 }
124
125 let node = ModelNode {
126 name: layer_name,
127 inputshape: layer_info.inputshape.clone(),
128 outputshape: layer_info.outputshape.clone(),
129 parameters: Some(layer_info.parameter_count),
130 layer_type: layer_info.layer_type.clone(),
131 properties,
132 };
133 nodes.push(node);
134 }
135
136 if let Some(inputshape) = inputshape {
138 let mut currentshape = inputshape;
141 for (i, node) in nodes.iter_mut().enumerate() {
142 if i > 0 {
143 node.inputshape = Some(currentshape.clone());
145 if node.layer_type == "Dense" {
147 if let Some(output_size) = extract_output_size(node) {
148 if !currentshape.is_empty() {
150 let mut outputshape = currentshape.clone();
151 if outputshape.len() > 1 {
152 let last_idx = outputshape.len() - 1;
153 outputshape[last_idx] = output_size;
154 } else {
155 outputshape = vec![output_size];
156 }
157 currentshape = outputshape.clone();
158 node.outputshape = Some(outputshape);
159 }
160 }
161 } else {
162 node.outputshape = Some(currentshape.clone());
164 }
165 }
166 }
167 }
168
169 let total_params: usize = nodes.iter().filter_map(|node| node.parameters).sum();
171 let name_width = nodes
173 .iter()
174 .map(|node| node.name.len())
175 .max()
176 .unwrap_or(10)
177 .max(10);
178 let type_width = nodes
179 .iter()
180 .map(|node| node.layer_type.len())
181 .max()
182 .unwrap_or(8)
183 .max(8);
184 let shape_width = if options.showshapes {
185 nodes
186 .iter()
187 .map(|node| {
188 let input_str = node.inputshape.as_ref().map(|s| format!("{s:?}"));
189 let output_str = node.outputshape.as_ref().map(|s| format!("{s:?}"));
190 let input_len = input_str.as_ref().map(|s| s.len()).unwrap_or(0);
191 let output_len = output_str.as_ref().map(|s| s.len()).unwrap_or(0);
192 input_len.max(output_len)
193 })
194 .max()
195 .unwrap_or(15)
196 .max(15)
197 } else {
198 0
199 };
200 let params_width = if options.show_params {
201 14 } else {
203 0
204 };
205
206 let mut header = format!(
208 "{:<width$} | {:<type_width$}",
209 if options.color_options.enabled {
210 stylize("Layer", Style::Bold).to_string()
211 } else {
212 "Layer".to_string()
213 },
214 if options.color_options.enabled {
215 stylize("Type", Style::Bold).to_string()
216 } else {
217 "Type".to_string()
218 },
219 width = name_width,
220 type_width = type_width
221 );
222 if options.showshapes {
223 header.push_str(&format!(
224 " | {:<shape_width$}",
225 if options.color_options.enabled {
226 stylize("Output Shape", Style::Bold).to_string()
227 } else {
228 "Output Shape".to_string()
229 },
230 shape_width = shape_width
231 ));
232 }
233 if options.show_params {
234 header.push_str(&format!(
235 " | {:<params_width$}",
236 if options.color_options.enabled {
237 stylize("Params", Style::Bold).to_string()
238 } else {
239 "Params".to_string()
240 },
241 params_width = params_width
242 ));
243 }
244
245 let mut result = String::new();
246 result.push_str(&header);
247 result.push('\n');
248 let total_width = name_width
250 + type_width
251 + (if options.showshapes {
252 shape_width + 3
253 } else {
254 0
255 })
256 + (if options.show_params {
257 params_width + 3
258 } else {
259 0
260 })
261 + 1;
262 result.push_str(&"-".repeat(total_width));
263 for node in &nodes {
265 let mut line = if options.color_options.enabled {
267 let styled_name = match node.layer_type.as_str() {
268 "Input" => colorize(&node.name, Color::BrightCyan),
269 "Dense" => colorize(&node.name, Color::BrightGreen),
270 "Conv2D" => colorize(&node.name, Color::BrightMagenta),
271 "RNN" | "LSTM" | "GRU" => colorize(&node.name, Color::BrightBlue),
272 "BatchNorm" | "Dropout" => colorize(&node.name, Color::Yellow),
273 _ => colorize(&node.name, Color::BrightWhite),
274 };
275 format!("{:<width$} | ", styled_name, width = name_width + 9) } else {
277 format!("{:<width$} | ", node.name, width = name_width)
278 };
279 line.push_str(&format!(
281 "{:<type_width$}",
282 node.layer_type,
283 type_width = type_width
284 ));
285
286 if options.showshapes {
288 let shape_str = if let Some(shape) = &node.outputshape {
289 format!("{shape:?}")
290 } else {
291 "?".to_string()
292 };
293 line.push_str(&format!(" | {shape_str:<shape_width$}"));
294 }
295
296 if options.show_params {
298 if let Some(params) = node.parameters {
299 let params_str = if params >= 1_000_000 {
300 let param_mb = params as f64 / 1_000_000.0;
301 format!("{param_mb:.2}M")
302 } else if params >= 1_000 {
303 let param_kb = params as f64 / 1_000.0;
304 format!("{param_kb:.2}K")
305 } else {
306 format!("{params}")
307 };
308 line.push_str(&format!(" | {params_str:<params_width$}"));
309 } else {
310 line.push_str(&format!(" | {question:<params_width$}", question = "?"));
311 }
312 }
313
314 result.push_str(&line);
315 result.push('\n');
316 if options.show_properties && !node.properties.is_empty() {
318 for (key, value) in &node.properties {
319 let prop_line = if options.color_options.enabled {
320 let styled_key = stylize(format!(" - {key}"), Style::Dim);
321 format!("{styled_key}: {value}")
322 } else {
323 format!(" - {key}: {value}")
324 };
325 result.push_str(&prop_line);
326 result.push('\n');
327 }
328 }
329 }
330
331 let trainable_params = total_params; let formatted_total = format_params(total_params);
334 let summary = format!("Total parameters: {formatted_total}");
335 if options.color_options.enabled {
336 result.push_str(&stylize(&summary, Style::Bold));
337 } else {
338 result.push_str(&summary);
339 }
340 result.push('\n');
341
342 let formatted_trainable = format_params(trainable_params);
344 let trainable_summary = format!("Trainable parameters: {formatted_trainable}");
345 if options.color_options.enabled {
346 result.push_str(&stylize(&trainable_summary, Style::Bold));
347 } else {
348 result.push_str(&trainable_summary);
349 }
350 result.push('\n');
351 let non_trainable_params = total_params - trainable_params;
353 let non_trainable_summary = format!(
354 "Non-trainable parameters: {}",
355 format_params(non_trainable_params)
356 );
357 if options.color_options.enabled {
358 result.push_str(&stylize(&non_trainable_summary, Style::Bold));
359 } else {
360 result.push_str(&non_trainable_summary);
361 }
362 result.push('\n');
363
364 Ok(result)
365}
366#[allow(dead_code)]
370pub fn sequential_model_dataflow<
371 F: Float + Debug + ScalarOperand + scirs2_core::numeric::FromPrimitive + std::fmt::Display,
372>(
373 model: &Sequential<F>,
374 inputshape: Vec<usize>,
375 options: Option<ModelVizOptions>,
376) -> Result<String> {
377 let options = options.unwrap_or_default();
378 let width = options.width;
379 let layer_infos = model.layer_info();
381 let mut nodes: Vec<ModelNode> = Vec::with_capacity(layer_infos.len() + 1);
382 nodes.push(ModelNode {
384 name: "Input".to_string(),
385 inputshape: None,
386 outputshape: Some(inputshape.clone()),
387 parameters: Some(0),
388 layer_type: "Input".to_string(),
389 properties: Vec::new(),
390 });
391 let mut currentshape = inputshape.clone();
393
394 for (i, layer_info) in layer_infos.iter().enumerate() {
395 let layer_name = if layer_info.name.starts_with("Layer_") {
396 let index = i + 1;
397 format!("Layer_{index}")
398 } else {
399 layer_info.name.clone()
400 };
401 let layer_type = layer_info.layer_type.clone();
402 let mut properties: Vec<(String, String)> = Vec::new();
403 if layer_info.parameter_count > 0 {
404 properties.push((
405 "Parameters".to_string(),
406 layer_info.parameter_count.to_string(),
407 ));
408 }
409 let inputshape = currentshape.clone();
410 let outputshape = match layer_type.as_str() {
412 "Dense" => {
413 if let Some(output_size) = properties
414 .iter()
415 .find(|(key, _)| key == "output_dim")
416 .map(|(_, value)| value.parse::<usize>().unwrap_or(0))
417 {
418 if !currentshape.is_empty() {
419 let mut newshape = currentshape.clone();
420 let last_idx = newshape.len() - 1;
421 newshape[last_idx] = output_size;
422 newshape
423 } else {
424 vec![output_size]
425 }
426 } else {
427 currentshape.clone()
428 }
429 }
430 "Conv2D" => {
431 if currentshape.len() >= 3 {
432 currentshape.clone()
434 } else {
435 currentshape.clone()
436 }
437 }
438 _ => currentshape.clone(),
439 };
440
441 currentshape = outputshape.clone();
442
443 let node = ModelNode {
444 name: layer_name,
445 inputshape: Some(inputshape),
446 outputshape: Some(outputshape),
447 parameters: Some(0), layer_type,
449 properties,
450 };
451 nodes.push(node);
452 }
453 let mut result = String::new();
465 let box_width = 20.min(width / 2);
466
467 for (i, node) in nodes.iter().enumerate() {
468 result.push_str(&" ".repeat((width - box_width) / 2));
470 result.push('┌');
471 result.push_str(&"─".repeat(box_width - 2));
472 result.push('┐');
473 result.push('\n');
474
475 let name = if node.layer_type == "Input" {
477 node.layer_type.clone()
478 } else {
479 format!("{} ({})", node.layer_type, node.name)
480 };
481 let padded_name = format!("{name:^width$}", width = box_width - 2);
482 result.push_str(&" ".repeat((width - box_width) / 2));
483
484 let styled_name = if options.color_options.enabled {
485 match node.layer_type.as_str() {
486 "Input" => colorize(&padded_name, Color::BrightCyan),
487 "Dense" => colorize(&padded_name, Color::BrightGreen),
488 "Conv2D" => colorize(&padded_name, Color::BrightMagenta),
489 "RNN" | "LSTM" | "GRU" => colorize(&padded_name, Color::BrightBlue),
490 "BatchNorm" | "Dropout" => colorize(&padded_name, Color::Yellow),
491 _ => padded_name.to_string(),
492 }
493 } else {
494 padded_name
495 };
496
497 result.push('│');
498 result.push_str(&styled_name);
499 result.push('│');
500 result.push('\n');
501
502 if let Some(shape) = &node.outputshape {
504 let shape_str = format!("{shape:?}");
505 let paddedshape = format!("{shape_str:^width$}", width = box_width - 2);
506 result.push_str(&" ".repeat((width - box_width) / 2));
507 result.push('│');
508 if options.color_options.enabled {
509 result.push_str(&stylize(&paddedshape, Style::Dim));
510 } else {
511 result.push_str(&paddedshape);
512 }
513 result.push('│');
514 result.push('\n');
515 }
516 result.push_str(&" ".repeat((width - box_width) / 2));
518 result.push('└');
519 result.push_str(&"─".repeat(box_width - 2));
520 result.push('┘');
521 result.push('\n');
522
523 if i < nodes.len() - 1 {
525 result.push_str(&" ".repeat(width / 2));
526 result.push('│');
527 result.push('\n');
528 result.push_str(&" ".repeat(width / 2));
529 result.push('▼');
530 result.push('\n');
531 }
532 }
533
534 let total_params: usize = nodes.iter().filter_map(|node| node.parameters).sum();
536 let formatted_total = format_params(total_params);
537 let summary = format!("Total parameters: {formatted_total}");
538 if options.color_options.enabled {
539 result.push_str(&stylize(&summary, Style::Bold));
540 } else {
541 result.push_str(&summary);
542 }
543 result.push('\n');
544
545 Ok(result)
546}
547#[allow(dead_code)]
549fn extract_output_size(node: &ModelNode) -> Option<usize> {
550 if node.layer_type == "Dense" {
551 for (key, value) in &node.properties {
552 if key == "output_dim" {
553 return value.parse::<usize>().ok();
554 }
555 }
556 }
557 None
558}
559#[allow(dead_code)]
561fn extract_layer_properties<F: Float + Debug + ScalarOperand>(
562 layer: &(dyn Layer<F> + Send + Sync),
563) -> Vec<(String, String)> {
564 let mut properties = Vec::new();
565 let description = layer.layer_description();
566 let parts: Vec<&str> = description.split(',').collect();
569 for part in parts {
570 let kv: Vec<&str> = part.split(':').collect();
571 if kv.len() == 2 {
572 let key = kv[0].trim().to_string();
573 let value = kv[1].trim().to_string();
574 if key != "type" && !key.is_empty() && !value.is_empty() {
575 properties.push((key, value));
576 }
577 }
578 }
579 properties
580}
/// Formats a parameter count for display.
///
/// Counts of at least one million are shown as `X.XXM (N parameters)`, counts
/// of at least one thousand as `X.XXK (N parameters)`, and smaller counts as
/// `N parameters`.
#[allow(dead_code)]
fn format_params(params: usize) -> String {
    // Choose a (divisor, suffix) scale for large counts; small counts print verbatim.
    let scale = if params >= 1_000_000 {
        Some((1_000_000.0, 'M'))
    } else if params >= 1_000 {
        Some((1_000.0, 'K'))
    } else {
        None
    };

    match scale {
        Some((divisor, suffix)) => {
            let scaled = params as f64 / divisor;
            format!("{scaled:.2}{suffix} ({params} parameters)")
        }
        None => format!("{params} parameters"),
    }
}