tract_libcli/
draw.rs

1use crate::display_params::DisplayParams;
2use crate::model::Model;
3use box_drawing::heavy::*;
4use nu_ansi_term::{Color, Style};
5use std::fmt;
6use std::fmt::Write;
7use tract_core::internal::*;
8
9#[derive(Clone)]
10pub struct Wire {
11    pub outlet: OutletId,
12    pub color: Option<Style>,
13    pub should_change_color: bool,
14    pub successors: Vec<InletId>,
15}
16
17impl fmt::Debug for Wire {
18    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
19        let s = format!("{:?} {:?}", self.outlet, self.successors);
20        if let Some(c) = self.color { write!(fmt, "{}", c.paint(s)) } else { write!(fmt, "{s}") }
21    }
22}
23
24#[derive(Clone, Default)]
25pub struct DrawingState {
26    pub current_color: Style,
27    pub latest_node_color: Style,
28    pub wires: Vec<Wire>,
29}
30
31impl DrawingState {
32    fn current_color(&self) -> Style {
33        self.current_color
34    }
35
36    fn next_color(&mut self) -> Style {
37        let colors = &[
38            Color::Red.normal(),
39            Color::Green.normal(),
40            Color::Yellow.normal(),
41            Color::Blue.normal(),
42            Color::Purple.normal(),
43            Color::Cyan.normal(),
44            Color::White.normal(),
45            Color::Red.bold(),
46            Color::Green.bold(),
47            Color::Yellow.bold(),
48            Color::Blue.bold(),
49            Color::Purple.bold(),
50            Color::Cyan.bold(),
51            Color::White.bold(),
52        ];
53        let color = colors
54            .iter()
55            .min_by_key(|&c| self.wires.iter().filter(|w| w.color == Some(*c)).count())
56            .unwrap();
57        self.current_color = *color;
58        *color
59    }
60
61    fn inputs_to_draw(&self, model: &dyn Model, node: usize) -> Vec<OutletId> {
62        model.node_inputs(node).to_vec()
63    }
64
65    fn passthrough_count(&self, node: usize) -> usize {
66        self.wires.iter().filter(|w| w.successors.iter().any(|i| i.node != node)).count()
67    }
68
69    pub fn draw_node_vprefix(
70        &mut self,
71        model: &dyn Model,
72        node: usize,
73        _opts: &DisplayParams,
74    ) -> TractResult<Vec<String>> {
75        let mut lines = vec![String::new()];
76        macro_rules! p { ($($args: expr),*) => { write!(lines.last_mut().unwrap(), $($args),*)?;} }
77        macro_rules! ln {
78            () => {
79                lines.push(String::new())
80            };
81        }
82        let passthrough_count = self.passthrough_count(node);
83        /*
84        println!("\n{}", model.node_format(node));
85        for (ix, w) in self.wires.iter().enumerate() {
86            println!(" * {} {:?}", ix, w);
87        }
88        */
89        for (ix, &input) in model.node_inputs(node).iter().enumerate().rev() {
90            let wire = self.wires.iter().position(|o| o.outlet == input).unwrap();
91            let wanted = passthrough_count + ix;
92            if wire != wanted {
93                let little = wire.min(wanted);
94                let big = wire.max(wanted);
95                let moving = self.wires[little].clone();
96                let must_clone = moving.successors.iter().any(|i| i.node != node);
97                let offset = self
98                    .wires
99                    .iter()
100                    .skip(little + 1)
101                    .take(big - little)
102                    .filter(|w| w.color.is_some())
103                    .count()
104                    + must_clone as usize;
105                // println!("{}->{} (offset: {})", little, big, offset);
106                if moving.color.is_some() && offset != 0 {
107                    let color = moving.color.unwrap();
108                    for w in &self.wires[0..little] {
109                        if let Some(c) = w.color {
110                            p!("{}", c.paint(VERTICAL));
111                        }
112                    }
113                    // println!("offset: {}", offset);
114                    p!("{}", color.paint(if must_clone { VERTICAL_RIGHT } else { UP_RIGHT }));
115                    for _ in 0..offset - 1 {
116                        p!("{}", color.paint(HORIZONTAL));
117                    }
118                    p!("{}", color.paint(DOWN_LEFT));
119                }
120                while self.wires.len() <= big {
121                    self.wires.push(Wire { successors: vec![], ..self.wires[little] });
122                }
123                if must_clone {
124                    self.wires[little].successors.retain(|&i| i != InletId::new(node, ix));
125                    self.wires[big] = Wire {
126                        successors: vec![InletId::new(node, ix)],
127                        should_change_color: true,
128                        ..self.wires[little]
129                    };
130                } else {
131                    for i in little..big {
132                        self.wires.swap(i, i + 1);
133                    }
134                }
135                if moving.color.is_some() {
136                    if big < self.wires.len() {
137                        for w in &self.wires[big + 1..] {
138                            if let Some(c) = w.color {
139                                p!("{}", c.paint(VERTICAL));
140                            } else {
141                                p!(" ");
142                            }
143                        }
144                    }
145                    ln!();
146                }
147            }
148        }
149        while lines.last().map(|s| s.trim()) == Some("") {
150            lines.pop();
151        }
152        Ok(lines)
153    }
154
155    pub fn draw_node_body(
156        &mut self,
157        model: &dyn Model,
158        node: usize,
159        opts: &DisplayParams,
160    ) -> TractResult<Vec<String>> {
161        let mut lines = vec![String::new()];
162        macro_rules! p { ($($args: expr),*) => { write!(lines.last_mut().unwrap(), $($args),*)?;} }
163        macro_rules! ln {
164            () => {
165                lines.push(String::new())
166            };
167        }
168        let inputs = self.inputs_to_draw(model, node);
169        let passthrough_count = self.passthrough_count(node);
170        let display = opts.konst || !model.node_const(node);
171        if display {
172            for wire in &self.wires[0..passthrough_count] {
173                if let Some(color) = wire.color {
174                    p!("{}", color.paint(VERTICAL));
175                }
176            }
177        }
178        let node_output_count = model.node_output_count(node);
179        if display {
180            self.latest_node_color = if !inputs.is_empty() {
181                let wire0 = &self.wires[passthrough_count];
182                if wire0.color.is_some() && !wire0.should_change_color {
183                    wire0.color.unwrap()
184                } else {
185                    self.next_color()
186                }
187            } else {
188                self.next_color()
189            };
190            match (inputs.len(), node_output_count) {
191                (0, 1) => {
192                    p!("{}", self.latest_node_color.paint(DOWN_RIGHT));
193                }
194                (1, 0) => {
195                    p!("{}", self.latest_node_color.paint("╹"));
196                }
197                (u, d) => {
198                    p!("{}", self.latest_node_color.paint(VERTICAL_RIGHT));
199                    for _ in 1..u.min(d) {
200                        p!("{}", self.latest_node_color.paint(VERTICAL_HORIZONTAL));
201                    }
202                    for _ in u..d {
203                        p!("{}", self.latest_node_color.paint(DOWN_HORIZONTAL));
204                    }
205                    for _ in d..u {
206                        p!("{}", self.latest_node_color.paint(UP_HORIZONTAL));
207                    }
208                }
209            }
210            ln!();
211        }
212        while lines.last().map(|s| s.trim()) == Some("") {
213            lines.pop();
214        }
215        Ok(lines)
216    }
217
218    pub fn draw_node_vfiller(&self, model: &dyn Model, node: usize) -> TractResult<String> {
219        let mut s = String::new();
220        for wire in &self.wires {
221            if let Some(color) = wire.color {
222                write!(&mut s, "{}", color.paint(VERTICAL))?;
223            }
224        }
225        for _ in self.wires.len()..model.node_output_count(node) {
226            write!(&mut s, " ")?;
227        }
228        Ok(s)
229    }
230
231    pub fn draw_node_vsuffix(
232        &mut self,
233        model: &dyn Model,
234        node: usize,
235        opts: &DisplayParams,
236    ) -> TractResult<Vec<String>> {
237        let mut lines = vec![];
238        let passthrough_count = self.passthrough_count(node);
239        let node_output_count = model.node_output_count(node);
240        let node_color = self
241            .wires
242            .get(passthrough_count)
243            .map(|w| w.color)
244            .unwrap_or_else(|| Some(self.current_color()));
245        self.wires.truncate(passthrough_count);
246        for slot in 0..node_output_count {
247            let outlet = OutletId::new(node, slot);
248            let successors = model.outlet_successors(outlet).to_vec();
249            let color = if !opts.konst && model.node_const(node) {
250                None
251            } else if slot == 0 && node_color.is_some() {
252                Some(self.latest_node_color)
253            } else {
254                Some(self.next_color())
255            };
256            self.wires.push(Wire { outlet, color, successors, should_change_color: false });
257        }
258        let wires_before = self.wires.clone();
259        self.wires.retain(|w| !w.successors.is_empty());
260        for (wanted_at, w) in self.wires.iter().enumerate() {
261            let is_at = wires_before.iter().position(|w2| w.outlet == w2.outlet).unwrap();
262            if wanted_at < is_at {
263                let mut s = String::new();
264                for w in 0..wanted_at {
265                    if let Some(color) = self.wires[w].color {
266                        write!(&mut s, "{}", color.paint(VERTICAL))?;
267                    }
268                }
269                if let Some(color) = self.wires[wanted_at].color {
270                    write!(&mut s, "{}", color.paint(DOWN_RIGHT))?;
271                    for w in is_at + 1..wanted_at {
272                        if self.wires[w].color.is_some() {
273                            write!(&mut s, "{}", color.paint(HORIZONTAL))?;
274                        }
275                    }
276                    write!(&mut s, "{}", color.paint(UP_LEFT))?;
277                    for w in is_at..self.wires.len() {
278                        if let Some(color) = self.wires[w].color {
279                            write!(&mut s, "{}", color.paint(VERTICAL))?;
280                        }
281                    }
282                }
283                lines.push(s);
284            }
285        }
286        // println!("{:?}", self.wires);
287        while lines.last().map(|s| s.trim()) == Some("") {
288            lines.pop();
289        }
290        Ok(lines)
291    }
292}