1use crate::display_params::DisplayParams;
2use crate::model::Model;
3use box_drawing::heavy::*;
4use nu_ansi_term::{Color, Style};
5use std::collections::HashSet;
6use std::fmt::Write;
7use tract_core::internal::*;
8
9#[derive(Clone, Debug)]
11struct HiddenWire {
12 successors: Vec<InletId>,
13}
14
15#[derive(Clone, Debug)]
17struct VisibleWire {
18 outlet: OutletId,
19 color: Style,
20 successors: Vec<InletId>,
21 should_change_color: bool,
22}
23
/// Returns the hollow circled glyph used to label model input `ix`.
///
/// Index 0 maps to `⓪`, indices 1 to 20 to `①` … `⑳`; anything larger falls back to `○`.
pub fn circled_input(ix: usize) -> char {
    // U+2460 CIRCLED DIGIT ONE; the following code points run contiguously up to twenty.
    const CIRCLED_ONE: u32 = 0x2460;
    if ix == 0 {
        '⓪'
    } else if ix <= 20 {
        char::from_u32(CIRCLED_ONE + ix as u32 - 1).unwrap()
    } else {
        '○'
    }
}
32
/// Returns the filled (negative) circled glyph used to label model output `ix`.
///
/// Index 0 maps to `⓿`, indices 1 to 10 to `❶` … `❿`; anything larger falls back to `●`.
pub fn circled_output(ix: usize) -> char {
    // U+2776 DINGBAT NEGATIVE CIRCLED DIGIT ONE; contiguous up to ten.
    const NEGATIVE_CIRCLED_ONE: u32 = 0x2776;
    if ix == 0 {
        '⓿'
    } else if ix <= 10 {
        char::from_u32(NEGATIVE_CIRCLED_ONE + ix as u32 - 1).unwrap()
    } else {
        '●'
    }
}
41
42#[derive(Clone, Default)]
43pub struct DrawingState {
44 hidden: Vec<HiddenWire>,
45 visible: Vec<VisibleWire>, latest_node_color: Style,
47 visited: HashSet<usize>,
48}
49
50impl DrawingState {
51 fn next_color(&self) -> Style {
52 let colors = &[
53 Color::Red.normal(),
54 Color::Green.normal(),
55 Color::Yellow.normal(),
56 Color::Blue.normal(),
57 Color::Purple.normal(),
58 Color::Cyan.normal(),
59 Color::White.normal(),
60 Color::Red.bold(),
61 Color::Green.bold(),
62 Color::Yellow.bold(),
63 Color::Blue.bold(),
64 Color::Purple.bold(),
65 Color::Cyan.bold(),
66 Color::White.bold(),
67 ];
68 *colors
69 .iter()
70 .min_by_key(|&c| self.visible.iter().filter(|w| w.color == *c).count())
71 .unwrap()
72 }
73
74 fn passthrough_count(&self, node: usize) -> usize {
76 self.visible.iter().filter(|w| w.successors.iter().any(|i| i.node != node)).count()
77 }
78
79 pub fn last_wire_color(&self) -> Style {
81 self.visible.last().map(|w| w.color).unwrap_or(self.latest_node_color)
82 }
83
84 fn render_filler(&self) -> String {
86 let mut s = String::new();
87 for w in &self.visible {
88 let _ = write!(s, "{}", w.color.paint(VERTICAL));
89 }
90 s
91 }
92
93 pub fn draw_node_vprefix(
94 &mut self,
95 model: &dyn Model,
96 node: usize,
97 _opts: &DisplayParams,
98 ) -> TractResult<Vec<String>> {
99 let mut lines = vec![];
100
101 self.visible.retain(|w| w.successors.iter().any(|i| !self.visited.contains(&i.node)));
103 self.hidden.retain(|w| w.successors.iter().any(|i| !self.visited.contains(&i.node)));
104
105 let inputs = model.node_inputs(node);
107 let mut passthroughs: Vec<VisibleWire> = Vec::new();
108 let mut input_wires: Vec<Option<VisibleWire>> = vec![None; inputs.len()];
109
110 for w in &self.visible {
111 let mut matched_input = None;
113 for (ix, &inlet) in inputs.iter().enumerate() {
114 if w.outlet == inlet {
115 matched_input = Some(ix);
116 break;
117 }
118 }
119
120 if let Some(ix) = matched_input {
121 let this_inlet = InletId::new(node, ix);
122 let must_clone = w.successors.iter().any(|i| *i != this_inlet);
123 if must_clone {
124 let mut pass_wire = w.clone();
127 pass_wire.successors.retain(|i| *i != this_inlet);
128 passthroughs.push(pass_wire);
129 input_wires[ix] = Some(VisibleWire {
130 outlet: w.outlet,
131 color: w.color,
132 successors: vec![this_inlet],
133 should_change_color: true,
134 });
135 } else {
136 input_wires[ix] = Some(w.clone());
138 }
139 } else {
140 passthroughs.push(w.clone());
141 }
142 }
143
144 let pt = passthroughs.len();
146 let mut target: Vec<VisibleWire> = passthroughs;
147 for w in input_wires.iter().flatten() {
148 target.push(w.clone());
149 }
150
151 let n_inputs_visible = input_wires.iter().filter(|w| w.is_some()).count();
154 let total_cols = pt + n_inputs_visible;
155 let mut slots: Vec<Option<VisibleWire>> = Vec::with_capacity(total_cols);
156 for w in &self.visible {
157 slots.push(Some(w.clone()));
158 }
159 while slots.len() < total_cols {
160 slots.push(None); }
162
163 for (ix, &inlet) in inputs.iter().enumerate().rev() {
168 let Some(ref input_wire) = input_wires[ix] else { continue };
169
170 let target_col = target
171 .iter()
172 .position(|w| w.outlet == inlet && w.successors.iter().any(|i| i.node == node))
173 .unwrap();
174
175 let cur_col =
176 match slots.iter().position(|s| s.as_ref().is_some_and(|w| w.outlet == inlet)) {
177 Some(c) => c,
178 None => continue,
179 };
180
181 let must_clone = input_wire.should_change_color; if cur_col == target_col && !must_clone {
184 continue;
185 }
186
187 let mut s = String::new();
189 let color = slots[cur_col].as_ref().unwrap().color;
190 let from = cur_col.min(target_col);
191 let to = cur_col.max(target_col);
192
193 for w in slots[..from].iter().flatten() {
195 let _ = write!(s, "{}", w.color.paint(VERTICAL));
196 }
197
198 if must_clone {
199 let _ = write!(s, "{}", color.paint(VERTICAL_RIGHT));
201 } else {
202 let _ = write!(s, "{}", color.paint(UP_RIGHT));
204 }
205 for _ in from + 1..to {
206 let _ = write!(s, "{}", color.paint(HORIZONTAL));
207 }
208 let _ = write!(s, "{}", color.paint(DOWN_LEFT));
209
210 for w in slots[to + 1..].iter().flatten() {
212 let _ = write!(s, "{}", w.color.paint(VERTICAL));
213 }
214
215 lines.push(s);
216
217 if must_clone {
219 slots[target_col] = Some(input_wire.clone());
221 } else {
222 slots[cur_col] = None;
224 slots[target_col] = Some(input_wire.clone());
225 }
226 }
227
228 self.visible = target;
230
231 lines.retain(|l: &String| !l.trim().is_empty());
232 Ok(lines)
233 }
234
235 pub fn draw_node_body(
236 &mut self,
237 model: &dyn Model,
238 node: usize,
239 opts: &DisplayParams,
240 ) -> TractResult<Vec<String>> {
241 let mut lines = vec![String::new()];
242 macro_rules! p { ($($args: expr),*) => { write!(lines.last_mut().unwrap(), $($args),*)?;} }
243 macro_rules! ln {
244 () => {
245 lines.push(String::new())
246 };
247 }
248
249 let inputs = model.node_inputs(node).to_vec();
250 let passthrough_count = self.passthrough_count(node);
251 let display = opts.konst || !model.node_const(node);
252
253 if display {
254 for w in &self.visible[..passthrough_count] {
256 p!("{}", w.color.paint(VERTICAL));
257 }
258
259 let node_output_count = model.node_output_count(node);
260
261 self.latest_node_color = if !inputs.is_empty() && passthrough_count < self.visible.len()
263 {
264 let wire0 = &self.visible[passthrough_count];
265 if !wire0.should_change_color { wire0.color } else { self.next_color() }
266 } else {
267 self.next_color()
268 };
269
270 match (inputs.len(), node_output_count) {
272 (0, 1) => {
273 let input_idx = model.input_outlets().iter().position(|o| o.node == node);
275 let symbol = match input_idx {
276 Some(i) => circled_input(i).to_string(),
277 _ => DOWN_RIGHT.to_string(),
278 };
279 p!("{}", self.latest_node_color.paint(symbol));
280 }
281 (1, 0) => {
282 p!("{}", self.latest_node_color.paint("╹"));
283 }
284 (u, d) => {
285 p!("{}", self.latest_node_color.paint(VERTICAL_RIGHT));
286 for _ in 1..u.min(d) {
287 p!("{}", self.latest_node_color.paint(VERTICAL_HORIZONTAL));
288 }
289 for _ in u..d {
290 p!("{}", self.latest_node_color.paint(DOWN_HORIZONTAL));
291 }
292 for _ in d..u {
293 p!("{}", self.latest_node_color.paint(UP_HORIZONTAL));
294 }
295 }
296 }
297 ln!();
298 }
299
300 while lines.last().map(|s| s.trim()) == Some("") {
301 lines.pop();
302 }
303 Ok(lines)
304 }
305
306 pub fn draw_node_vfiller(&self, _model: &dyn Model, _node: usize) -> TractResult<String> {
307 Ok(self.render_filler())
308 }
309
310 pub fn draw_node_vsuffix(
311 &mut self,
312 model: &dyn Model,
313 node: usize,
314 opts: &DisplayParams,
315 ) -> TractResult<Vec<String>> {
316 self.visited.insert(node);
318 let mut lines = vec![];
319 let passthrough_count = self.passthrough_count(node);
320 let node_output_count = model.node_output_count(node);
321
322 self.visible.truncate(passthrough_count);
324
325 for slot in 0..node_output_count {
327 let outlet = OutletId::new(node, slot);
328 let successors = model.outlet_successors(outlet).to_vec();
329 let color = if !opts.konst && model.node_const(node) {
330 self.hidden.push(HiddenWire { successors });
332 continue;
333 } else if slot == 0 {
334 self.latest_node_color
335 } else {
336 self.next_color()
337 };
338 self.visible.push(VisibleWire {
339 outlet,
340 color,
341 successors,
342 should_change_color: false,
343 });
344 }
345
346 let model_outputs = model.output_outlets();
348 let has_output_marker = self.visible.iter().any(|w| model_outputs.contains(&w.outlet));
349 if has_output_marker {
350 let mut s = String::new();
351 for w in &self.visible {
352 if model_outputs.contains(&w.outlet) {
353 let output_idx = model_outputs.iter().position(|o| *o == w.outlet);
354 let symbol = match output_idx {
355 Some(i) => circled_output(i),
356 _ => '●',
357 };
358 let _ = write!(s, "{}", w.color.paint(symbol.to_string()));
359 } else {
360 let _ = write!(s, "{}", w.color.paint(VERTICAL));
361 }
362 }
363 lines.push(s);
364 }
365
366 self.visible.retain(|w| !w.successors.is_empty());
368
369 lines.retain(|l: &String| !l.trim().is_empty());
370 Ok(lines)
371 }
372}
373
374#[cfg(test)]
375mod tests {
376 use super::*;
377 use crate::display_params::DisplayParams;
378 use crate::model::Model;
379 use tract_core::ops::identity::Identity;
380 use tract_core::ops::math;
381
382 fn strip_ansi(s: &str) -> String {
383 let mut out = String::new();
384 let mut in_escape = false;
385 for c in s.chars() {
386 if in_escape {
387 if c == 'm' {
388 in_escape = false;
389 }
390 } else if c == '\x1b' {
391 in_escape = true;
392 } else {
393 out.push(c);
394 }
395 }
396 out
397 }
398
399 fn draw_all(model: &dyn Model, ds: &mut DrawingState, node: usize) -> Vec<String> {
400 let opts = DisplayParams { konst: true, ..DisplayParams::default() };
401 let mut lines = vec![];
402 for l in ds.draw_node_vprefix(model, node, &opts).unwrap() {
403 lines.push(strip_ansi(&l));
404 }
405 for l in ds.draw_node_body(model, node, &opts).unwrap() {
406 lines.push(strip_ansi(&l));
407 }
408 for l in ds.draw_node_vsuffix(model, node, &opts).unwrap() {
409 lines.push(strip_ansi(&l));
410 }
411 lines.retain(|l| !l.trim().is_empty());
412 lines
413 }
414
415 #[test]
417 fn linear_chain() -> TractResult<()> {
418 let mut model = TypedModel::default();
419 let s = model.add_source("s", f32::fact([1]))?;
420 let _id = model.wire_node("id", Identity, &[s])?[0];
421 model.auto_outputs()?;
422 let mut ds = DrawingState::default();
423 let lines0 = draw_all(&model, &mut ds, 0);
424 assert_eq!(lines0, vec!["⓪"]); let lines1 = draw_all(&model, &mut ds, 1);
426 assert_eq!(lines1[0], VERTICAL_RIGHT); assert!(lines1.len() == 2 && lines1[1] == "⓿"); Ok(())
429 }
430
431 #[test]
433 fn fanin_from_one_source() -> TractResult<()> {
434 let mut model = TypedModel::default();
435 let s = model.add_source("s", f32::fact([1]))?;
436 let _add = model.wire_node("add", math::add(), &[s, s])?[0];
437 model.auto_outputs()?;
438 let mut ds = DrawingState::default();
439 let lines0 = draw_all(&model, &mut ds, 0);
440 assert_eq!(lines0, vec!["⓪"]); let lines1 = draw_all(&model, &mut ds, 1);
442 let joined = lines1.join("|");
443 assert!(
444 joined.contains(UP_HORIZONTAL), "Expected merge pattern, got: {lines1:?}"
446 );
447 Ok(())
448 }
449
450 #[test]
452 fn fork_after_merge() -> TractResult<()> {
453 let mut model = TypedModel::default();
454 let a = model.add_source("a", f32::fact([1]))?;
455 let b = model.add_source("b", f32::fact([1]))?;
456 let add = model.wire_node("add", math::add(), &[a, b])?[0];
457 let _id1 = model.wire_node("id1", Identity, &[add])?[0];
458 let _id2 = model.wire_node("id2", Identity, &[add])?[0];
459 model.auto_outputs()?;
460 let mut ds = DrawingState::default();
461 draw_all(&model, &mut ds, 0); draw_all(&model, &mut ds, 1); let lines_add = draw_all(&model, &mut ds, 2); let joined = lines_add.join("|");
465 assert!(
466 joined.contains(UP_HORIZONTAL), "Expected merge in body, got: {lines_add:?}"
468 );
469 let lines_id1 = draw_all(&model, &mut ds, 3); assert!(!lines_id1.is_empty(), "id1 should render");
471 Ok(())
472 }
473
474 #[test]
476 fn no_blank_prefix_lines() -> TractResult<()> {
477 let mut model = TypedModel::default();
478 let a = model.add_source("a", f32::fact([1]))?;
479 let b = model.add_source("b", f32::fact([1]))?;
480 let add = model.wire_node("add", math::add(), &[a, b])?[0];
481 let _id = model.wire_node("id", Identity, &[add])?[0];
482 model.auto_outputs()?;
483 let opts = DisplayParams { konst: true, ..DisplayParams::default() };
484 let mut ds = DrawingState::default();
485 let order = model.eval_order()?;
486 for &node in &order {
487 let prefix = ds.draw_node_vprefix(&model, node, &opts).unwrap();
488 for (i, l) in prefix.iter().enumerate() {
489 let stripped = strip_ansi(l);
490 assert!(
491 !stripped.trim().is_empty() || i == prefix.len() - 1,
492 "Blank line at position {i} in prefix for node {node}: {prefix:?}"
493 );
494 }
495 ds.draw_node_body(&model, node, &opts).unwrap();
496 ds.draw_node_vsuffix(&model, node, &opts).unwrap();
497 }
498 Ok(())
499 }
500
501 #[test]
503 fn filler_width_matches_visible() -> TractResult<()> {
504 let mut model = TypedModel::default();
505 let a = model.add_source("a", f32::fact([1]))?;
506 let b = model.add_source("b", f32::fact([1]))?;
507 let add = model.wire_node("add", math::add(), &[a, b])?[0];
508 let _id1 = model.wire_node("id1", Identity, &[add])?[0];
509 let _id2 = model.wire_node("id2", Identity, &[add])?[0];
510 model.auto_outputs()?;
511 let opts = DisplayParams { konst: true, ..DisplayParams::default() };
512 let mut ds = DrawingState::default();
513 let order = model.eval_order()?;
514 for &node in &order {
515 ds.draw_node_vprefix(&model, node, &opts).unwrap();
516 ds.draw_node_body(&model, node, &opts).unwrap();
517 ds.draw_node_vsuffix(&model, node, &opts).unwrap();
518 let filler = ds.draw_node_vfiller(&model, node).unwrap();
519 let filler_w = strip_ansi(&filler).chars().count();
520 let visible_count = ds.visible.len();
521 assert_eq!(
522 filler_w, visible_count,
523 "Filler width {filler_w} != visible wire count {visible_count} for node {node}"
524 );
525 }
526 Ok(())
527 }
528}