1use crate::display_params::DisplayParams;
2use crate::model::Model;
3use box_drawing::heavy::*;
4use nu_ansi_term::{Color, Style};
5use std::fmt;
6use std::fmt::Write;
7use tract_core::internal::*;
8
9#[derive(Clone)]
10pub struct Wire {
11 pub outlet: OutletId,
12 pub color: Option<Style>,
13 pub should_change_color: bool,
14 pub successors: Vec<InletId>,
15}
16
17impl fmt::Debug for Wire {
18 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
19 let s = format!("{:?} {:?}", self.outlet, self.successors);
20 if let Some(c) = self.color { write!(fmt, "{}", c.paint(s)) } else { write!(fmt, "{s}") }
21 }
22}
23
24#[derive(Clone, Default)]
25pub struct DrawingState {
26 pub current_color: Style,
27 pub latest_node_color: Style,
28 pub wires: Vec<Wire>,
29}
30
31impl DrawingState {
32 fn current_color(&self) -> Style {
33 self.current_color
34 }
35
36 fn next_color(&mut self) -> Style {
37 let colors = &[
38 Color::Red.normal(),
39 Color::Green.normal(),
40 Color::Yellow.normal(),
41 Color::Blue.normal(),
42 Color::Purple.normal(),
43 Color::Cyan.normal(),
44 Color::White.normal(),
45 Color::Red.bold(),
46 Color::Green.bold(),
47 Color::Yellow.bold(),
48 Color::Blue.bold(),
49 Color::Purple.bold(),
50 Color::Cyan.bold(),
51 Color::White.bold(),
52 ];
53 let color = colors
54 .iter()
55 .min_by_key(|&c| self.wires.iter().filter(|w| w.color == Some(*c)).count())
56 .unwrap();
57 self.current_color = *color;
58 *color
59 }
60
61 fn inputs_to_draw(&self, model: &dyn Model, node: usize) -> Vec<OutletId> {
62 model.node_inputs(node).to_vec()
63 }
64
65 fn passthrough_count(&self, node: usize) -> usize {
66 self.wires.iter().filter(|w| w.successors.iter().any(|i| i.node != node)).count()
67 }
68
69 pub fn draw_node_vprefix(
70 &mut self,
71 model: &dyn Model,
72 node: usize,
73 _opts: &DisplayParams,
74 ) -> TractResult<Vec<String>> {
75 let mut lines = vec![String::new()];
76 macro_rules! p { ($($args: expr),*) => { write!(lines.last_mut().unwrap(), $($args),*)?;} }
77 macro_rules! ln {
78 () => {
79 lines.push(String::new())
80 };
81 }
82 let passthrough_count = self.passthrough_count(node);
83 for (ix, &input) in model.node_inputs(node).iter().enumerate().rev() {
90 let wire = self.wires.iter().position(|o| o.outlet == input).unwrap();
91 let wanted = passthrough_count + ix;
92 if wire != wanted {
93 let little = wire.min(wanted);
94 let big = wire.max(wanted);
95 let moving = self.wires[little].clone();
96 let must_clone = moving.successors.iter().any(|i| i.node != node);
97 let offset = self
98 .wires
99 .iter()
100 .skip(little + 1)
101 .take(big - little)
102 .filter(|w| w.color.is_some())
103 .count()
104 + must_clone as usize;
105 #[allow(clippy::unnecessary_unwrap)]
107 if moving.color.is_some() && offset != 0 {
108 let color = moving.color.unwrap();
109 for w in &self.wires[0..little] {
110 if let Some(c) = w.color {
111 p!("{}", c.paint(VERTICAL));
112 }
113 }
114 p!("{}", color.paint(if must_clone { VERTICAL_RIGHT } else { UP_RIGHT }));
116 for _ in 0..offset - 1 {
117 p!("{}", color.paint(HORIZONTAL));
118 }
119 p!("{}", color.paint(DOWN_LEFT));
120 }
121 while self.wires.len() <= big {
122 self.wires.push(Wire { successors: vec![], ..self.wires[little] });
123 }
124 if must_clone {
125 self.wires[little].successors.retain(|&i| i != InletId::new(node, ix));
126 self.wires[big] = Wire {
127 successors: vec![InletId::new(node, ix)],
128 should_change_color: true,
129 ..self.wires[little]
130 };
131 } else {
132 for i in little..big {
133 self.wires.swap(i, i + 1);
134 }
135 }
136 if moving.color.is_some() {
137 if big < self.wires.len() {
138 for w in &self.wires[big + 1..] {
139 if let Some(c) = w.color {
140 p!("{}", c.paint(VERTICAL));
141 } else {
142 p!(" ");
143 }
144 }
145 }
146 ln!();
147 }
148 }
149 }
150 while lines.last().map(|s| s.trim()) == Some("") {
151 lines.pop();
152 }
153 Ok(lines)
154 }
155
156 pub fn draw_node_body(
157 &mut self,
158 model: &dyn Model,
159 node: usize,
160 opts: &DisplayParams,
161 ) -> TractResult<Vec<String>> {
162 let mut lines = vec![String::new()];
163 macro_rules! p { ($($args: expr),*) => { write!(lines.last_mut().unwrap(), $($args),*)?;} }
164 macro_rules! ln {
165 () => {
166 lines.push(String::new())
167 };
168 }
169 let inputs = self.inputs_to_draw(model, node);
170 let passthrough_count = self.passthrough_count(node);
171 let display = opts.konst || !model.node_const(node);
172 if display {
173 for wire in &self.wires[0..passthrough_count] {
174 if let Some(color) = wire.color {
175 p!("{}", color.paint(VERTICAL));
176 }
177 }
178 }
179 let node_output_count = model.node_output_count(node);
180 if display {
181 self.latest_node_color = if !inputs.is_empty() {
182 let wire0 = &self.wires[passthrough_count];
183 #[allow(clippy::unnecessary_unwrap)]
184 if wire0.color.is_some() && !wire0.should_change_color {
185 wire0.color.unwrap()
186 } else {
187 self.next_color()
188 }
189 } else {
190 self.next_color()
191 };
192 match (inputs.len(), node_output_count) {
193 (0, 1) => {
194 p!("{}", self.latest_node_color.paint(DOWN_RIGHT));
195 }
196 (1, 0) => {
197 p!("{}", self.latest_node_color.paint("╹"));
198 }
199 (u, d) => {
200 p!("{}", self.latest_node_color.paint(VERTICAL_RIGHT));
201 for _ in 1..u.min(d) {
202 p!("{}", self.latest_node_color.paint(VERTICAL_HORIZONTAL));
203 }
204 for _ in u..d {
205 p!("{}", self.latest_node_color.paint(DOWN_HORIZONTAL));
206 }
207 for _ in d..u {
208 p!("{}", self.latest_node_color.paint(UP_HORIZONTAL));
209 }
210 }
211 }
212 ln!();
213 }
214 while lines.last().map(|s| s.trim()) == Some("") {
215 lines.pop();
216 }
217 Ok(lines)
218 }
219
220 pub fn draw_node_vfiller(&self, model: &dyn Model, node: usize) -> TractResult<String> {
221 let mut s = String::new();
222 for wire in &self.wires {
223 if let Some(color) = wire.color {
224 write!(&mut s, "{}", color.paint(VERTICAL))?;
225 }
226 }
227 for _ in self.wires.len()..model.node_output_count(node) {
228 write!(&mut s, " ")?;
229 }
230 Ok(s)
231 }
232
233 pub fn draw_node_vsuffix(
234 &mut self,
235 model: &dyn Model,
236 node: usize,
237 opts: &DisplayParams,
238 ) -> TractResult<Vec<String>> {
239 let mut lines = vec![];
240 let passthrough_count = self.passthrough_count(node);
241 let node_output_count = model.node_output_count(node);
242 let node_color = self
243 .wires
244 .get(passthrough_count)
245 .map(|w| w.color)
246 .unwrap_or_else(|| Some(self.current_color()));
247 self.wires.truncate(passthrough_count);
248 for slot in 0..node_output_count {
249 let outlet = OutletId::new(node, slot);
250 let successors = model.outlet_successors(outlet).to_vec();
251 let color = if !opts.konst && model.node_const(node) {
252 None
253 } else if slot == 0 && node_color.is_some() {
254 Some(self.latest_node_color)
255 } else {
256 Some(self.next_color())
257 };
258 self.wires.push(Wire { outlet, color, successors, should_change_color: false });
259 }
260 let wires_before = self.wires.clone();
261 self.wires.retain(|w| !w.successors.is_empty());
262 for (wanted_at, w) in self.wires.iter().enumerate() {
263 let is_at = wires_before.iter().position(|w2| w.outlet == w2.outlet).unwrap();
264 if wanted_at < is_at {
265 let mut s = String::new();
266 for w in 0..wanted_at {
267 if let Some(color) = self.wires[w].color {
268 write!(&mut s, "{}", color.paint(VERTICAL))?;
269 }
270 }
271 if let Some(color) = self.wires[wanted_at].color {
272 write!(&mut s, "{}", color.paint(DOWN_RIGHT))?;
273 for w in is_at + 1..wanted_at {
274 if self.wires[w].color.is_some() {
275 write!(&mut s, "{}", color.paint(HORIZONTAL))?;
276 }
277 }
278 write!(&mut s, "{}", color.paint(UP_LEFT))?;
279 for w in is_at..self.wires.len() {
280 if let Some(color) = self.wires[w].color {
281 write!(&mut s, "{}", color.paint(VERTICAL))?;
282 }
283 }
284 }
285 lines.push(s);
286 }
287 }
288 while lines.last().map(|s| s.trim()) == Some("") {
290 lines.pop();
291 }
292 Ok(lines)
293 }
294}