Skip to main content

tract_libcli/
draw.rs

1use crate::display_params::DisplayParams;
2use crate::model::Model;
3use box_drawing::heavy::*;
4use nu_ansi_term::{Color, Style};
5use std::collections::HashSet;
6use std::fmt::Write;
7use tract_core::internal::*;
8
9/// A wire that is not rendered (const node output when konst=false).
10#[derive(Clone, Debug)]
11struct HiddenWire {
12    successors: Vec<InletId>,
13}
14
15/// A wire that occupies a visual column.
16#[derive(Clone, Debug)]
17struct VisibleWire {
18    outlet: OutletId,
19    color: Style,
20    successors: Vec<InletId>,
21    should_change_color: bool,
22}
23
/// Glyph marking model input number `ix`.
///
/// Returns ⓪ for 0, ①…⑳ for 1 through 20, and a plain ○ for anything larger.
pub fn circled_input(ix: usize) -> char {
    // CIRCLED DIGIT ONE; ① through ⑳ are contiguous code points starting here.
    const CIRCLED_ONE: u32 = 0x2460;
    if ix == 0 {
        return '⓪';
    }
    if ix > 20 {
        return '○';
    }
    char::from_u32(CIRCLED_ONE + ix as u32 - 1).unwrap()
}
32
/// Filled circled number for model outputs.
///
/// Returns ⓿ for 0, ❶…❿ for 1–10, ⓫…⓴ for 11–20, and a plain ● beyond that. Outputs therefore
/// stay distinguishable over the same 0–20 range as `circled_input`.
pub fn circled_output(ix: usize) -> char {
    match ix {
        0 => '⓿',
        // DINGBAT NEGATIVE CIRCLED DIGIT ONE..TEN: U+2776..=U+277F.
        1..=10 => char::from_u32(0x2776 + (ix as u32 - 1)).unwrap(),
        // NEGATIVE CIRCLED NUMBER ELEVEN..TWENTY: U+24EB..=U+24F4.
        11..=20 => char::from_u32(0x24EB + (ix as u32 - 11)).unwrap(),
        _ => '●',
    }
}
41
42#[derive(Clone, Default)]
43pub struct DrawingState {
44    hidden: Vec<HiddenWire>,
45    visible: Vec<VisibleWire>, // index = visual column
46    latest_node_color: Style,
47    visited: HashSet<usize>,
48}
49
50impl DrawingState {
51    fn next_color(&self) -> Style {
52        let colors = &[
53            Color::Red.normal(),
54            Color::Green.normal(),
55            Color::Yellow.normal(),
56            Color::Blue.normal(),
57            Color::Purple.normal(),
58            Color::Cyan.normal(),
59            Color::White.normal(),
60            Color::Red.bold(),
61            Color::Green.bold(),
62            Color::Yellow.bold(),
63            Color::Blue.bold(),
64            Color::Purple.bold(),
65            Color::Cyan.bold(),
66            Color::White.bold(),
67        ];
68        *colors
69            .iter()
70            .min_by_key(|&c| self.visible.iter().filter(|w| w.color == *c).count())
71            .unwrap()
72    }
73
74    /// Number of visible wires that pass through (have successors to nodes other than `node`).
75    fn passthrough_count(&self, node: usize) -> usize {
76        self.visible.iter().filter(|w| w.successors.iter().any(|i| i.node != node)).count()
77    }
78
79    /// Color of the last visible wire, or the latest node color.
80    pub fn last_wire_color(&self) -> Style {
81        self.visible.last().map(|w| w.color).unwrap_or(self.latest_node_color)
82    }
83
84    /// Render a filler line: one ┃ per visible wire.
85    fn render_filler(&self) -> String {
86        let mut s = String::new();
87        for w in &self.visible {
88            let _ = write!(s, "{}", w.color.paint(VERTICAL));
89        }
90        s
91    }
92
93    pub fn draw_node_vprefix(
94        &mut self,
95        model: &dyn Model,
96        node: usize,
97        _opts: &DisplayParams,
98    ) -> TractResult<Vec<String>> {
99        let mut lines = vec![];
100
101        // Prune wires whose only remaining successors are all already visited.
102        self.visible.retain(|w| w.successors.iter().any(|i| !self.visited.contains(&i.node)));
103        self.hidden.retain(|w| w.successors.iter().any(|i| !self.visited.contains(&i.node)));
104
105        // Build target layout: passthroughs in current order, then visible inputs in input order.
106        let inputs = model.node_inputs(node);
107        let mut passthroughs: Vec<VisibleWire> = Vec::new();
108        let mut input_wires: Vec<Option<VisibleWire>> = vec![None; inputs.len()];
109
110        for w in &self.visible {
111            // Check if this wire feeds any input of this node
112            let mut matched_input = None;
113            for (ix, &inlet) in inputs.iter().enumerate() {
114                if w.outlet == inlet {
115                    matched_input = Some(ix);
116                    break;
117                }
118            }
119
120            if let Some(ix) = matched_input {
121                let this_inlet = InletId::new(node, ix);
122                let must_clone = w.successors.iter().any(|i| *i != this_inlet);
123                if must_clone {
124                    // Wire feeds this node AND others: clone it.
125                    // Original (with other successors) stays as passthrough.
126                    let mut pass_wire = w.clone();
127                    pass_wire.successors.retain(|i| *i != this_inlet);
128                    passthroughs.push(pass_wire);
129                    input_wires[ix] = Some(VisibleWire {
130                        outlet: w.outlet,
131                        color: w.color,
132                        successors: vec![this_inlet],
133                        should_change_color: true,
134                    });
135                } else {
136                    // Wire feeds only this node: move entirely to input position.
137                    input_wires[ix] = Some(w.clone());
138                }
139            } else {
140                passthroughs.push(w.clone());
141            }
142        }
143
144        // Target = passthroughs ++ visible input wires
145        let pt = passthroughs.len();
146        let mut target: Vec<VisibleWire> = passthroughs;
147        for w in input_wires.iter().flatten() {
148            target.push(w.clone());
149        }
150
151        // Build working state with empty slots for the input region.
152        // Cols 0..pt are passthroughs (occupied), cols pt..target.len() start empty.
153        let n_inputs_visible = input_wires.iter().filter(|w| w.is_some()).count();
154        let total_cols = pt + n_inputs_visible;
155        let mut slots: Vec<Option<VisibleWire>> = Vec::with_capacity(total_cols);
156        for w in &self.visible {
157            slots.push(Some(w.clone()));
158        }
159        while slots.len() < total_cols {
160            slots.push(None); // empty reserved slots
161        }
162
163        // Process inputs right to left. For each input:
164        // - Find the wire in `slots` (by outlet)
165        // - Compute its target column in the final layout
166        // - Render the routing line and update slots
167        for (ix, &inlet) in inputs.iter().enumerate().rev() {
168            let Some(ref input_wire) = input_wires[ix] else { continue };
169
170            let target_col = target
171                .iter()
172                .position(|w| w.outlet == inlet && w.successors.iter().any(|i| i.node == node))
173                .unwrap();
174
175            let cur_col =
176                match slots.iter().position(|s| s.as_ref().is_some_and(|w| w.outlet == inlet)) {
177                    Some(c) => c,
178                    None => continue,
179                };
180
181            let must_clone = input_wire.should_change_color; // proxy: cloned wires have this set
182
183            if cur_col == target_col && !must_clone {
184                continue;
185            }
186
187            // Render the routing line from cur_col to target_col.
188            let mut s = String::new();
189            let color = slots[cur_col].as_ref().unwrap().color;
190            let from = cur_col.min(target_col);
191            let to = cur_col.max(target_col);
192
193            // Leading verticals (cols before the leftmost endpoint)
194            for w in slots[..from].iter().flatten() {
195                let _ = write!(s, "{}", w.color.paint(VERTICAL));
196            }
197
198            if must_clone {
199                // Split: ┣ at cur_col, horizontals in between, ┓ at target_col
200                let _ = write!(s, "{}", color.paint(VERTICAL_RIGHT));
201            } else {
202                // Swap: ┗ at cur_col, horizontals in between, ┓ at target_col
203                let _ = write!(s, "{}", color.paint(UP_RIGHT));
204            }
205            for _ in from + 1..to {
206                let _ = write!(s, "{}", color.paint(HORIZONTAL));
207            }
208            let _ = write!(s, "{}", color.paint(DOWN_LEFT));
209
210            // Trailing verticals (cols after the rightmost endpoint)
211            for w in slots[to + 1..].iter().flatten() {
212                let _ = write!(s, "{}", w.color.paint(VERTICAL));
213            }
214
215            lines.push(s);
216
217            // Update slots: place the wire/clone at target_col
218            if must_clone {
219                // Original stays at cur_col, clone goes to target_col
220                slots[target_col] = Some(input_wire.clone());
221            } else {
222                // Move: remove from cur_col, place at target_col
223                slots[cur_col] = None;
224                slots[target_col] = Some(input_wire.clone());
225            }
226        }
227
228        // Set final state
229        self.visible = target;
230
231        lines.retain(|l: &String| !l.trim().is_empty());
232        Ok(lines)
233    }
234
235    pub fn draw_node_body(
236        &mut self,
237        model: &dyn Model,
238        node: usize,
239        opts: &DisplayParams,
240    ) -> TractResult<Vec<String>> {
241        let mut lines = vec![String::new()];
242        macro_rules! p { ($($args: expr),*) => { write!(lines.last_mut().unwrap(), $($args),*)?;} }
243        macro_rules! ln {
244            () => {
245                lines.push(String::new())
246            };
247        }
248
249        let inputs = model.node_inputs(node).to_vec();
250        let passthrough_count = self.passthrough_count(node);
251        let display = opts.konst || !model.node_const(node);
252
253        if display {
254            // Draw passthrough verticals
255            for w in &self.visible[..passthrough_count] {
256                p!("{}", w.color.paint(VERTICAL));
257            }
258
259            let node_output_count = model.node_output_count(node);
260
261            // Determine node color
262            self.latest_node_color = if !inputs.is_empty() && passthrough_count < self.visible.len()
263            {
264                let wire0 = &self.visible[passthrough_count];
265                if !wire0.should_change_color { wire0.color } else { self.next_color() }
266            } else {
267                self.next_color()
268            };
269
270            // Draw junction
271            match (inputs.len(), node_output_count) {
272                (0, 1) => {
273                    // Source node: use circled number if it's a model input
274                    let input_idx = model.input_outlets().iter().position(|o| o.node == node);
275                    let symbol = match input_idx {
276                        Some(i) => circled_input(i).to_string(),
277                        _ => DOWN_RIGHT.to_string(),
278                    };
279                    p!("{}", self.latest_node_color.paint(symbol));
280                }
281                (1, 0) => {
282                    p!("{}", self.latest_node_color.paint("╹"));
283                }
284                (u, d) => {
285                    p!("{}", self.latest_node_color.paint(VERTICAL_RIGHT));
286                    for _ in 1..u.min(d) {
287                        p!("{}", self.latest_node_color.paint(VERTICAL_HORIZONTAL));
288                    }
289                    for _ in u..d {
290                        p!("{}", self.latest_node_color.paint(DOWN_HORIZONTAL));
291                    }
292                    for _ in d..u {
293                        p!("{}", self.latest_node_color.paint(UP_HORIZONTAL));
294                    }
295                }
296            }
297            ln!();
298        }
299
300        while lines.last().map(|s| s.trim()) == Some("") {
301            lines.pop();
302        }
303        Ok(lines)
304    }
305
306    pub fn draw_node_vfiller(&self, _model: &dyn Model, _node: usize) -> TractResult<String> {
307        Ok(self.render_filler())
308    }
309
310    pub fn draw_node_vsuffix(
311        &mut self,
312        model: &dyn Model,
313        node: usize,
314        opts: &DisplayParams,
315    ) -> TractResult<Vec<String>> {
316        // Mark node as visited now that its inputs have been consumed.
317        self.visited.insert(node);
318        let mut lines = vec![];
319        let passthrough_count = self.passthrough_count(node);
320        let node_output_count = model.node_output_count(node);
321
322        // Remove input wires (keep passthroughs)
323        self.visible.truncate(passthrough_count);
324
325        // Add output wires
326        for slot in 0..node_output_count {
327            let outlet = OutletId::new(node, slot);
328            let successors = model.outlet_successors(outlet).to_vec();
329            let color = if !opts.konst && model.node_const(node) {
330                // Const node: wire goes to hidden, not visible
331                self.hidden.push(HiddenWire { successors });
332                continue;
333            } else if slot == 0 {
334                self.latest_node_color
335            } else {
336                self.next_color()
337            };
338            self.visible.push(VisibleWire {
339                outlet,
340                color,
341                successors,
342                should_change_color: false,
343            });
344        }
345
346        // Mark model outputs with a circled number on a filler line.
347        let model_outputs = model.output_outlets();
348        let has_output_marker = self.visible.iter().any(|w| model_outputs.contains(&w.outlet));
349        if has_output_marker {
350            let mut s = String::new();
351            for w in &self.visible {
352                if model_outputs.contains(&w.outlet) {
353                    let output_idx = model_outputs.iter().position(|o| *o == w.outlet);
354                    let symbol = match output_idx {
355                        Some(i) => circled_output(i),
356                        _ => '●',
357                    };
358                    let _ = write!(s, "{}", w.color.paint(symbol.to_string()));
359                } else {
360                    let _ = write!(s, "{}", w.color.paint(VERTICAL));
361                }
362            }
363            lines.push(s);
364        }
365
366        // Remove wires with no successors
367        self.visible.retain(|w| !w.successors.is_empty());
368
369        lines.retain(|l: &String| !l.trim().is_empty());
370        Ok(lines)
371    }
372}
373
#[cfg(test)]
mod tests {
    use super::*;
    use crate::display_params::DisplayParams;
    use crate::model::Model;
    use tract_core::ops::identity::Identity;
    use tract_core::ops::math;

    /// Remove ANSI SGR escape sequences (`ESC … m`) so assertions only see the glyphs.
    fn strip_ansi(s: &str) -> String {
        let mut out = String::new();
        let mut in_escape = false;
        for c in s.chars() {
            if in_escape {
                if c == 'm' {
                    in_escape = false;
                }
            } else if c == '\x1b' {
                in_escape = true;
            } else {
                out.push(c);
            }
        }
        out
    }

    /// Run prefix, body and suffix for `node` (consts shown) and return the non-blank,
    /// ANSI-stripped lines.
    fn draw_all(model: &dyn Model, ds: &mut DrawingState, node: usize) -> Vec<String> {
        let opts = DisplayParams { konst: true, ..DisplayParams::default() };
        let mut lines = vec![];
        for l in ds.draw_node_vprefix(model, node, &opts).unwrap() {
            lines.push(strip_ansi(&l));
        }
        for l in ds.draw_node_body(model, node, &opts).unwrap() {
            lines.push(strip_ansi(&l));
        }
        for l in ds.draw_node_vsuffix(model, node, &opts).unwrap() {
            lines.push(strip_ansi(&l));
        }
        lines.retain(|l| !l.trim().is_empty());
        lines
    }

    /// Source → Identity (linear chain, no branching)
    #[test]
    fn linear_chain() -> TractResult<()> {
        let mut model = TypedModel::default();
        let s = model.add_source("s", f32::fact([1]))?;
        let _id = model.wire_node("id", Identity, &[s])?[0];
        model.auto_outputs()?;
        let mut ds = DrawingState::default();
        let lines0 = draw_all(&model, &mut ds, 0);
        assert_eq!(lines0, vec!["⓪"]); // circled 0 (first model input)
        let lines1 = draw_all(&model, &mut ds, 1);
        // No routing needed: ┣ (1 in, 1 out), then the output marker.
        assert_eq!(lines1, vec![VERTICAL_RIGHT, "⓿"]);
        Ok(())
    }

    /// Source → Add(source, source) — fan-in from one source to two inputs
    #[test]
    fn fanin_from_one_source() -> TractResult<()> {
        let mut model = TypedModel::default();
        let s = model.add_source("s", f32::fact([1]))?;
        let _add = model.wire_node("add", math::add(), &[s, s])?[0];
        model.auto_outputs()?;
        let mut ds = DrawingState::default();
        let lines0 = draw_all(&model, &mut ds, 0);
        assert_eq!(lines0, vec!["⓪"]); // circled 0 (first model input)
        let lines1 = draw_all(&model, &mut ds, 1);
        // The source is split (┣┓) so each input gets a column, then merged (┣┻).
        assert_eq!(
            lines1,
            vec![
                format!("{VERTICAL_RIGHT}{DOWN_LEFT}"),
                format!("{VERTICAL_RIGHT}{UP_HORIZONTAL}"),
                "⓿".to_string(),
            ]
        );
        Ok(())
    }

    /// Two sources → Add → two consumers (fork)
    #[test]
    fn fork_after_merge() -> TractResult<()> {
        let mut model = TypedModel::default();
        let a = model.add_source("a", f32::fact([1]))?;
        let b = model.add_source("b", f32::fact([1]))?;
        let add = model.wire_node("add", math::add(), &[a, b])?[0];
        let _id1 = model.wire_node("id1", Identity, &[add])?[0];
        let _id2 = model.wire_node("id2", Identity, &[add])?[0];
        model.auto_outputs()?;
        let mut ds = DrawingState::default();
        draw_all(&model, &mut ds, 0); // source a
        draw_all(&model, &mut ds, 1); // source b
        let lines_add = draw_all(&model, &mut ds, 2); // add (2 inputs, 1 output)
        // Inputs are already in order: no routing, just the ┣┻ merge.
        assert_eq!(lines_add, vec![format!("{VERTICAL_RIGHT}{UP_HORIZONTAL}")]);
        let lines_id1 = draw_all(&model, &mut ds, 3); // id1
        assert!(lines_id1.len() >= 2, "id1 should render routing and body: {lines_id1:?}");
        // `add` also feeds id2, so it is split off (┣┓) and keeps a passthrough column.
        assert_eq!(lines_id1[0], format!("{VERTICAL_RIGHT}{DOWN_LEFT}"));
        assert_eq!(lines_id1[1], format!("{VERTICAL}{VERTICAL_RIGHT}"));
        Ok(())
    }

    /// No blank lines in prefix output (regression for leading-empty-line bug)
    #[test]
    fn no_blank_prefix_lines() -> TractResult<()> {
        let mut model = TypedModel::default();
        let a = model.add_source("a", f32::fact([1]))?;
        let b = model.add_source("b", f32::fact([1]))?;
        let add = model.wire_node("add", math::add(), &[a, b])?[0];
        let _id = model.wire_node("id", Identity, &[add])?[0];
        model.auto_outputs()?;
        let opts = DisplayParams { konst: true, ..DisplayParams::default() };
        let mut ds = DrawingState::default();
        let order = model.eval_order()?;
        for &node in &order {
            let prefix = ds.draw_node_vprefix(&model, node, &opts).unwrap();
            // Routing lines are only emitted when a wire moves or splits: none may be blank.
            assert!(
                prefix.iter().all(|l| !strip_ansi(l).trim().is_empty()),
                "Blank line in prefix for node {node}: {prefix:?}"
            );
            ds.draw_node_body(&model, node, &opts).unwrap();
            ds.draw_node_vsuffix(&model, node, &opts).unwrap();
        }
        Ok(())
    }

    /// Filler width matches the number of visible wires (post-suffix state)
    #[test]
    fn filler_width_matches_visible() -> TractResult<()> {
        let mut model = TypedModel::default();
        let a = model.add_source("a", f32::fact([1]))?;
        let b = model.add_source("b", f32::fact([1]))?;
        let add = model.wire_node("add", math::add(), &[a, b])?[0];
        let _id1 = model.wire_node("id1", Identity, &[add])?[0];
        let _id2 = model.wire_node("id2", Identity, &[add])?[0];
        model.auto_outputs()?;
        let opts = DisplayParams { konst: true, ..DisplayParams::default() };
        let mut ds = DrawingState::default();
        let order = model.eval_order()?;
        for &node in &order {
            ds.draw_node_vprefix(&model, node, &opts).unwrap();
            ds.draw_node_body(&model, node, &opts).unwrap();
            ds.draw_node_vsuffix(&model, node, &opts).unwrap();
            let filler = ds.draw_node_vfiller(&model, node).unwrap();
            let filler_w = strip_ansi(&filler).chars().count();
            let visible_count = ds.visible.len();
            assert_eq!(
                filler_w, visible_count,
                "Filler width {filler_w} != visible wire count {visible_count} for node {node}"
            );
        }
        Ok(())
    }
}
528}