// tract_libcli/draw.rs

1use crate::display_params::DisplayParams;
2use crate::model::Model;
3use nu_ansi_term::{Color, Style};
4use box_drawing::heavy::*;
5use std::fmt;
6use std::fmt::Write;
7use tract_core::internal::*;
8
9#[derive(Clone)]
10pub struct Wire {
11    pub outlet: OutletId,
12    pub color: Option<Style>,
13    pub should_change_color: bool,
14    pub successors: Vec<InletId>,
15}
16
17impl fmt::Debug for Wire {
18    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
19        let s = format!("{:?} {:?}", self.outlet, self.successors);
20        if let Some(c) = self.color {
21            write!(fmt, "{}", c.paint(s))
22        } else {
23            write!(fmt, "{s}")
24        }
25    }
26}
27
28#[derive(Clone, Default)]
29pub struct DrawingState {
30    pub current_color: Style,
31    pub latest_node_color: Style,
32    pub wires: Vec<Wire>,
33}
34
35impl DrawingState {
36    fn current_color(&self) -> Style {
37        self.current_color
38    }
39
40    fn next_color(&mut self) -> Style {
41        let colors = &[
42            Color::Red.normal(),
43            Color::Green.normal(),
44            Color::Yellow.normal(),
45            Color::Blue.normal(),
46            Color::Purple.normal(),
47            Color::Cyan.normal(),
48            Color::White.normal(),
49            Color::Red.bold(),
50            Color::Green.bold(),
51            Color::Yellow.bold(),
52            Color::Blue.bold(),
53            Color::Purple.bold(),
54            Color::Cyan.bold(),
55            Color::White.bold(),
56        ];
57        let color = colors
58            .iter()
59            .min_by_key(|&c| self.wires.iter().filter(|w| w.color == Some(*c)).count())
60            .unwrap();
61        self.current_color = *color;
62        *color
63    }
64
65    fn inputs_to_draw(&self, model: &dyn Model, node: usize) -> Vec<OutletId> {
66        model.node_inputs(node).to_vec()
67    }
68
69    fn passthrough_count(&self, node: usize) -> usize {
70        self.wires.iter().filter(|w| w.successors.iter().any(|i| i.node != node)).count()
71    }
72
73    pub fn draw_node_vprefix(
74        &mut self,
75        model: &dyn Model,
76        node: usize,
77        _opts: &DisplayParams,
78    ) -> TractResult<Vec<String>> {
79        let mut lines = vec![String::new()];
80        macro_rules! p { ($($args: expr),*) => { write!(lines.last_mut().unwrap(), $($args),*)?;} }
81        macro_rules! ln {
82            () => {
83                lines.push(String::new())
84            };
85        }
86        let passthrough_count = self.passthrough_count(node);
87        /*
88        println!("\n{}", model.node_format(node));
89        for (ix, w) in self.wires.iter().enumerate() {
90            println!(" * {} {:?}", ix, w);
91        }
92        */
93        for (ix, &input) in model.node_inputs(node).iter().enumerate().rev() {
94            let wire = self.wires.iter().position(|o| o.outlet == input).unwrap();
95            let wanted = passthrough_count + ix;
96            if wire != wanted {
97                let little = wire.min(wanted);
98                let big = wire.max(wanted);
99                let moving = self.wires[little].clone();
100                let must_clone = moving.successors.iter().any(|i| i.node != node);
101                let offset = self
102                    .wires
103                    .iter()
104                    .skip(little + 1)
105                    .take(big - little)
106                    .filter(|w| w.color.is_some())
107                    .count()
108                    + must_clone as usize;
109                // println!("{}->{} (offset: {})", little, big, offset);
110                if moving.color.is_some() && offset != 0 {
111                    let color = moving.color.unwrap();
112                    for w in &self.wires[0..little] {
113                        if let Some(c) = w.color {
114                            p!("{}", c.paint(VERTICAL));
115                        }
116                    }
117                    // println!("offset: {}", offset);
118                    p!("{}", color.paint(if must_clone { VERTICAL_RIGHT } else { UP_RIGHT }));
119                    for _ in 0..offset - 1 {
120                        p!("{}", color.paint(HORIZONTAL));
121                    }
122                    p!("{}", color.paint(DOWN_LEFT));
123                }
124                while self.wires.len() <= big {
125                    self.wires.push(Wire { successors: vec![], ..self.wires[little] });
126                }
127                if must_clone {
128                    self.wires[little].successors.retain(|&i| i != InletId::new(node, ix));
129                    self.wires[big] = Wire {
130                        successors: vec![InletId::new(node, ix)],
131                        should_change_color: true,
132                        ..self.wires[little]
133                    };
134                } else {
135                    for i in little..big {
136                        self.wires.swap(i, i + 1);
137                    }
138                }
139                if moving.color.is_some() {
140                    if big < self.wires.len() {
141                        for w in &self.wires[big + 1..] {
142                            if let Some(c) = w.color {
143                                p!("{}", c.paint(VERTICAL));
144                            } else {
145                                p!(" ");
146                            }
147                        }
148                    }
149                    ln!();
150                }
151            }
152        }
153        while lines.last().map(|s| s.trim()) == Some("") {
154            lines.pop();
155        }
156        Ok(lines)
157    }
158
159    pub fn draw_node_body(
160        &mut self,
161        model: &dyn Model,
162        node: usize,
163        opts: &DisplayParams,
164    ) -> TractResult<Vec<String>> {
165        let mut lines = vec![String::new()];
166        macro_rules! p { ($($args: expr),*) => { write!(lines.last_mut().unwrap(), $($args),*)?;} }
167        macro_rules! ln {
168            () => {
169                lines.push(String::new())
170            };
171        }
172        let inputs = self.inputs_to_draw(model, node);
173        let passthrough_count = self.passthrough_count(node);
174        let display = opts.konst || !model.node_const(node);
175        if display {
176            for wire in &self.wires[0..passthrough_count] {
177                if let Some(color) = wire.color {
178                    p!("{}", color.paint(VERTICAL));
179                }
180            }
181        }
182        let node_output_count = model.node_output_count(node);
183        if display {
184            self.latest_node_color = if !inputs.is_empty() {
185                let wire0 = &self.wires[passthrough_count];
186                if wire0.color.is_some() && !wire0.should_change_color {
187                    wire0.color.unwrap()
188                } else {
189                    self.next_color()
190                }
191            } else {
192                self.next_color()
193            };
194            match (inputs.len(), node_output_count) {
195                (0, 1) => {
196                    p!("{}", self.latest_node_color.paint(DOWN_RIGHT));
197                }
198                (1, 0) => {
199                    p!("{}", self.latest_node_color.paint("╹"));
200                }
201                (u, d) => {
202                    p!("{}", self.latest_node_color.paint(VERTICAL_RIGHT));
203                    for _ in 1..u.min(d) {
204                        p!("{}", self.latest_node_color.paint(VERTICAL_HORIZONTAL));
205                    }
206                    for _ in u..d {
207                        p!("{}", self.latest_node_color.paint(DOWN_HORIZONTAL));
208                    }
209                    for _ in d..u {
210                        p!("{}", self.latest_node_color.paint(UP_HORIZONTAL));
211                    }
212                }
213            }
214            ln!();
215        }
216        while lines.last().map(|s| s.trim()) == Some("") {
217            lines.pop();
218        }
219        Ok(lines)
220    }
221
222    pub fn draw_node_vfiller(&self, model: &dyn Model, node: usize) -> TractResult<String> {
223        let mut s = String::new();
224        for wire in &self.wires {
225            if let Some(color) = wire.color {
226                write!(&mut s, "{}", color.paint(VERTICAL))?;
227            }
228        }
229        for _ in self.wires.len()..model.node_output_count(node) {
230            write!(&mut s, " ")?;
231        }
232        Ok(s)
233    }
234
235    pub fn draw_node_vsuffix(
236        &mut self,
237        model: &dyn Model,
238        node: usize,
239        opts: &DisplayParams,
240    ) -> TractResult<Vec<String>> {
241        let mut lines = vec![];
242        let passthrough_count = self.passthrough_count(node);
243        let node_output_count = model.node_output_count(node);
244        let node_color = self
245            .wires
246            .get(passthrough_count)
247            .map(|w| w.color)
248            .unwrap_or_else(|| Some(self.current_color()));
249        self.wires.truncate(passthrough_count);
250        for slot in 0..node_output_count {
251            let outlet = OutletId::new(node, slot);
252            let successors = model.outlet_successors(outlet).to_vec();
253            let color = if !opts.konst && model.node_const(node) {
254                None
255            } else if slot == 0 && node_color.is_some() {
256                Some(self.latest_node_color)
257            } else {
258                Some(self.next_color())
259            };
260            self.wires.push(Wire { outlet, color, successors, should_change_color: false });
261        }
262        let wires_before = self.wires.clone();
263        self.wires.retain(|w| !w.successors.is_empty());
264        for (wanted_at, w) in self.wires.iter().enumerate() {
265            let is_at = wires_before.iter().position(|w2| w.outlet == w2.outlet).unwrap();
266            if wanted_at < is_at {
267                let mut s = String::new();
268                for w in 0..wanted_at {
269                    if let Some(color) = self.wires[w].color {
270                        write!(&mut s, "{}", color.paint(VERTICAL))?;
271                    }
272                }
273                if let Some(color) = self.wires[wanted_at].color {
274                    write!(&mut s, "{}", color.paint(DOWN_RIGHT))?;
275                    for w in is_at + 1..wanted_at {
276                        if self.wires[w].color.is_some() {
277                            write!(&mut s, "{}", color.paint(HORIZONTAL))?;
278                        }
279                    }
280                    write!(&mut s, "{}", color.paint(UP_LEFT))?;
281                    for w in is_at..self.wires.len() {
282                        if let Some(color) = self.wires[w].color {
283                            write!(&mut s, "{}", color.paint(VERTICAL))?;
284                        }
285                    }
286                }
287                lines.push(s);
288            }
289        }
290        // println!("{:?}", self.wires);
291        while lines.last().map(|s| s.trim()) == Some("") {
292            lines.pop();
293        }
294        Ok(lines)
295    }
296}