1use crate::display_params::DisplayParams;
2use crate::model::Model;
3use box_drawing::heavy::*;
4use nu_ansi_term::{Color, Style};
5use std::fmt;
6use std::fmt::Write;
7use tract_core::internal::*;
8
9#[derive(Clone)]
10pub struct Wire {
11 pub outlet: OutletId,
12 pub color: Option<Style>,
13 pub should_change_color: bool,
14 pub successors: Vec<InletId>,
15}
16
17impl fmt::Debug for Wire {
18 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
19 let s = format!("{:?} {:?}", self.outlet, self.successors);
20 if let Some(c) = self.color { write!(fmt, "{}", c.paint(s)) } else { write!(fmt, "{s}") }
21 }
22}
23
24#[derive(Clone, Default)]
25pub struct DrawingState {
26 pub current_color: Style,
27 pub latest_node_color: Style,
28 pub wires: Vec<Wire>,
29}
30
31impl DrawingState {
32 fn current_color(&self) -> Style {
33 self.current_color
34 }
35
36 fn next_color(&mut self) -> Style {
37 let colors = &[
38 Color::Red.normal(),
39 Color::Green.normal(),
40 Color::Yellow.normal(),
41 Color::Blue.normal(),
42 Color::Purple.normal(),
43 Color::Cyan.normal(),
44 Color::White.normal(),
45 Color::Red.bold(),
46 Color::Green.bold(),
47 Color::Yellow.bold(),
48 Color::Blue.bold(),
49 Color::Purple.bold(),
50 Color::Cyan.bold(),
51 Color::White.bold(),
52 ];
53 let color = colors
54 .iter()
55 .min_by_key(|&c| self.wires.iter().filter(|w| w.color == Some(*c)).count())
56 .unwrap();
57 self.current_color = *color;
58 *color
59 }
60
61 fn inputs_to_draw(&self, model: &dyn Model, node: usize) -> Vec<OutletId> {
62 model.node_inputs(node).to_vec()
63 }
64
65 fn passthrough_count(&self, node: usize) -> usize {
66 self.wires.iter().filter(|w| w.successors.iter().any(|i| i.node != node)).count()
67 }
68
69 pub fn draw_node_vprefix(
70 &mut self,
71 model: &dyn Model,
72 node: usize,
73 _opts: &DisplayParams,
74 ) -> TractResult<Vec<String>> {
75 let mut lines = vec![String::new()];
76 macro_rules! p { ($($args: expr),*) => { write!(lines.last_mut().unwrap(), $($args),*)?;} }
77 macro_rules! ln {
78 () => {
79 lines.push(String::new())
80 };
81 }
82 let passthrough_count = self.passthrough_count(node);
83 for (ix, &input) in model.node_inputs(node).iter().enumerate().rev() {
90 let wire = self.wires.iter().position(|o| o.outlet == input).unwrap();
91 let wanted = passthrough_count + ix;
92 if wire != wanted {
93 let little = wire.min(wanted);
94 let big = wire.max(wanted);
95 let moving = self.wires[little].clone();
96 let must_clone = moving.successors.iter().any(|i| i.node != node);
97 let offset = self
98 .wires
99 .iter()
100 .skip(little + 1)
101 .take(big - little)
102 .filter(|w| w.color.is_some())
103 .count()
104 + must_clone as usize;
105 if moving.color.is_some() && offset != 0 {
107 let color = moving.color.unwrap();
108 for w in &self.wires[0..little] {
109 if let Some(c) = w.color {
110 p!("{}", c.paint(VERTICAL));
111 }
112 }
113 p!("{}", color.paint(if must_clone { VERTICAL_RIGHT } else { UP_RIGHT }));
115 for _ in 0..offset - 1 {
116 p!("{}", color.paint(HORIZONTAL));
117 }
118 p!("{}", color.paint(DOWN_LEFT));
119 }
120 while self.wires.len() <= big {
121 self.wires.push(Wire { successors: vec![], ..self.wires[little] });
122 }
123 if must_clone {
124 self.wires[little].successors.retain(|&i| i != InletId::new(node, ix));
125 self.wires[big] = Wire {
126 successors: vec![InletId::new(node, ix)],
127 should_change_color: true,
128 ..self.wires[little]
129 };
130 } else {
131 for i in little..big {
132 self.wires.swap(i, i + 1);
133 }
134 }
135 if moving.color.is_some() {
136 if big < self.wires.len() {
137 for w in &self.wires[big + 1..] {
138 if let Some(c) = w.color {
139 p!("{}", c.paint(VERTICAL));
140 } else {
141 p!(" ");
142 }
143 }
144 }
145 ln!();
146 }
147 }
148 }
149 while lines.last().map(|s| s.trim()) == Some("") {
150 lines.pop();
151 }
152 Ok(lines)
153 }
154
155 pub fn draw_node_body(
156 &mut self,
157 model: &dyn Model,
158 node: usize,
159 opts: &DisplayParams,
160 ) -> TractResult<Vec<String>> {
161 let mut lines = vec![String::new()];
162 macro_rules! p { ($($args: expr),*) => { write!(lines.last_mut().unwrap(), $($args),*)?;} }
163 macro_rules! ln {
164 () => {
165 lines.push(String::new())
166 };
167 }
168 let inputs = self.inputs_to_draw(model, node);
169 let passthrough_count = self.passthrough_count(node);
170 let display = opts.konst || !model.node_const(node);
171 if display {
172 for wire in &self.wires[0..passthrough_count] {
173 if let Some(color) = wire.color {
174 p!("{}", color.paint(VERTICAL));
175 }
176 }
177 }
178 let node_output_count = model.node_output_count(node);
179 if display {
180 self.latest_node_color = if !inputs.is_empty() {
181 let wire0 = &self.wires[passthrough_count];
182 if wire0.color.is_some() && !wire0.should_change_color {
183 wire0.color.unwrap()
184 } else {
185 self.next_color()
186 }
187 } else {
188 self.next_color()
189 };
190 match (inputs.len(), node_output_count) {
191 (0, 1) => {
192 p!("{}", self.latest_node_color.paint(DOWN_RIGHT));
193 }
194 (1, 0) => {
195 p!("{}", self.latest_node_color.paint("╹"));
196 }
197 (u, d) => {
198 p!("{}", self.latest_node_color.paint(VERTICAL_RIGHT));
199 for _ in 1..u.min(d) {
200 p!("{}", self.latest_node_color.paint(VERTICAL_HORIZONTAL));
201 }
202 for _ in u..d {
203 p!("{}", self.latest_node_color.paint(DOWN_HORIZONTAL));
204 }
205 for _ in d..u {
206 p!("{}", self.latest_node_color.paint(UP_HORIZONTAL));
207 }
208 }
209 }
210 ln!();
211 }
212 while lines.last().map(|s| s.trim()) == Some("") {
213 lines.pop();
214 }
215 Ok(lines)
216 }
217
218 pub fn draw_node_vfiller(&self, model: &dyn Model, node: usize) -> TractResult<String> {
219 let mut s = String::new();
220 for wire in &self.wires {
221 if let Some(color) = wire.color {
222 write!(&mut s, "{}", color.paint(VERTICAL))?;
223 }
224 }
225 for _ in self.wires.len()..model.node_output_count(node) {
226 write!(&mut s, " ")?;
227 }
228 Ok(s)
229 }
230
231 pub fn draw_node_vsuffix(
232 &mut self,
233 model: &dyn Model,
234 node: usize,
235 opts: &DisplayParams,
236 ) -> TractResult<Vec<String>> {
237 let mut lines = vec![];
238 let passthrough_count = self.passthrough_count(node);
239 let node_output_count = model.node_output_count(node);
240 let node_color = self
241 .wires
242 .get(passthrough_count)
243 .map(|w| w.color)
244 .unwrap_or_else(|| Some(self.current_color()));
245 self.wires.truncate(passthrough_count);
246 for slot in 0..node_output_count {
247 let outlet = OutletId::new(node, slot);
248 let successors = model.outlet_successors(outlet).to_vec();
249 let color = if !opts.konst && model.node_const(node) {
250 None
251 } else if slot == 0 && node_color.is_some() {
252 Some(self.latest_node_color)
253 } else {
254 Some(self.next_color())
255 };
256 self.wires.push(Wire { outlet, color, successors, should_change_color: false });
257 }
258 let wires_before = self.wires.clone();
259 self.wires.retain(|w| !w.successors.is_empty());
260 for (wanted_at, w) in self.wires.iter().enumerate() {
261 let is_at = wires_before.iter().position(|w2| w.outlet == w2.outlet).unwrap();
262 if wanted_at < is_at {
263 let mut s = String::new();
264 for w in 0..wanted_at {
265 if let Some(color) = self.wires[w].color {
266 write!(&mut s, "{}", color.paint(VERTICAL))?;
267 }
268 }
269 if let Some(color) = self.wires[wanted_at].color {
270 write!(&mut s, "{}", color.paint(DOWN_RIGHT))?;
271 for w in is_at + 1..wanted_at {
272 if self.wires[w].color.is_some() {
273 write!(&mut s, "{}", color.paint(HORIZONTAL))?;
274 }
275 }
276 write!(&mut s, "{}", color.paint(UP_LEFT))?;
277 for w in is_at..self.wires.len() {
278 if let Some(color) = self.wires[w].color {
279 write!(&mut s, "{}", color.paint(VERTICAL))?;
280 }
281 }
282 }
283 lines.push(s);
284 }
285 }
286 while lines.last().map(|s| s.trim()) == Some("") {
288 lines.pop();
289 }
290 Ok(lines)
291 }
292}