Skip to main content

tract_libcli/
terminal.rs

1use std::time::Duration;
2
3use crate::annotations::*;
4use crate::display_params::*;
5use crate::draw::DrawingState;
6use crate::model::Model;
7use nu_ansi_term::AnsiString;
8use nu_ansi_term::Color::*;
9#[allow(unused_imports)]
10use std::convert::TryFrom;
11use tract_core::internal::*;
12use tract_core::num_traits::AsPrimitive;
13use tract_itertools::Itertools;
14
15pub fn render(
16    model: &dyn Model,
17    annotations: &Annotations,
18    options: &DisplayParams,
19) -> TractResult<()> {
20    if options.quiet {
21        return Ok(());
22    }
23    render_prefixed(model, "", &[], annotations, options)?;
24    if !model.properties().is_empty() {
25        println!("{}", White.bold().paint("# Properties"));
26    }
27    for (k, v) in model.properties().iter().sorted_by_key(|(k, _)| k.to_string()) {
28        println!("* {}: {:?}", White.paint(k), v)
29    }
30    let symbols = model.symbols();
31    if !symbols.all_assertions().is_empty() {
32        println!("{}", White.bold().paint("# Assertions"));
33        for a in symbols.all_assertions() {
34            println!(" * {a}");
35        }
36    }
37    for (ix, scenario) in symbols.all_scenarios().into_iter().enumerate() {
38        if ix == 0 {
39            println!("{}", White.bold().paint("# Scenarios"));
40        }
41        for a in scenario.1 {
42            println!(" * {}: {}", scenario.0, a);
43        }
44    }
45    Ok(())
46}
47
48pub fn render_node(
49    model: &dyn Model,
50    node_id: usize,
51    annotations: &Annotations,
52    options: &DisplayParams,
53) -> TractResult<()> {
54    render_node_prefixed(model, "", &[], node_id, None, annotations, options)
55}
56
57fn render_prefixed(
58    model: &dyn Model,
59    prefix: &str,
60    scope: &[(usize, String)],
61    annotations: &Annotations,
62    options: &DisplayParams,
63) -> TractResult<()> {
64    let mut drawing_state =
65        if options.should_draw() { Some(DrawingState::default()) } else { None };
66    let node_ids = options.order(model)?;
67    for node in node_ids {
68        if options.filter(model, scope, node)? {
69            render_node_prefixed(
70                model,
71                prefix,
72                scope,
73                node,
74                drawing_state.as_mut(),
75                annotations,
76                options,
77            )?
78        } else if let Some(ref mut ds) = drawing_state {
79            let _prefix = ds.draw_node_vprefix(model, node, options)?;
80            let _body = ds.draw_node_body(model, node, options)?;
81            let _suffix = ds.draw_node_vsuffix(model, node, options)?;
82        }
83    }
84    Ok(())
85}
86
87pub fn si_prefix(v: impl AsPrimitive<f64>, unit: &str) -> String {
88    radical_prefix(v, unit, 1000, "")
89}
90
91pub fn pow2_prefix(v: impl AsPrimitive<f64>, unit: &str) -> String {
92    radical_prefix(v, unit, 1024, "i")
93}
94
95pub fn radical_prefix(
96    v: impl AsPrimitive<f64>,
97    unit: &str,
98    radical: usize,
99    radical_prefix: &str,
100) -> String {
101    let v: f64 = v.as_();
102    let radical = radical as f64;
103    let radical3 = radical.powi(3);
104    let radical2 = radical.powi(2);
105    if v > radical3 {
106        format!("{:7.3} G{}{}", v / radical3, radical_prefix, unit)
107    } else if v > 1e6 {
108        format!("{:7.3} M{}{}", v / radical2, radical_prefix, unit)
109    } else if v > 1e3 {
110        format!("{:7.3} k{}{}", v / radical, radical_prefix, unit)
111    } else {
112        format!("{v:7.3}  {unit}")
113    }
114}
115
116fn render_node_prefixed(
117    model: &dyn Model,
118    prefix: &str,
119    scope: &[(usize, String)],
120    node_id: usize,
121    mut drawing_state: Option<&mut DrawingState>,
122    annotations: &Annotations,
123    options: &DisplayParams,
124) -> TractResult<()> {
125    let qid = NodeQId(scope.into(), node_id);
126    let tags = annotations.tags.get(&qid).cloned().unwrap_or_default();
127    let name_color = tags.style.unwrap_or_else(|| White.into());
128    let node_name = model.node_name(node_id);
129    let node_op_name = model.node_op_name(node_id);
130    let profile_column_pad = format!("{:>1$}", "", options.profile as usize * 20);
131    let cost_column_pad = format!("{:>1$}", "", options.cost as usize * 25);
132    let mem_padding = if annotations.memory_summary.is_some() { 15 } else { 30 };
133    let tmp_mem_usage_column_pad =
134        format!("{:>1$}", "", options.tmp_mem_usage as usize * mem_padding);
135    let flops_column_pad = format!("{:>1$}", "", (options.profile && options.cost) as usize * 20);
136
137    if let Some(ds) = &mut drawing_state {
138        for l in ds.draw_node_vprefix(model, node_id, options)? {
139            println!(
140                "{cost_column_pad}{profile_column_pad}{flops_column_pad}{tmp_mem_usage_column_pad}{prefix}{l} "
141            );
142        }
143    }
144
145    // profile column
146    let mut profile_column = tags.profile.map(|measure| {
147        let profile_summary = annotations.profile_summary.as_ref().unwrap();
148        let use_micros = profile_summary.sum < Duration::from_millis(1);
149        let ratio = measure.as_secs_f64() / profile_summary.sum.as_secs_f64();
150        let ratio_for_color = measure.as_secs_f64() / profile_summary.max.as_secs_f64();
151        let color = colorous::RED_YELLOW_GREEN.eval_continuous(1.0 - ratio_for_color);
152        let color = nu_ansi_term::Color::Rgb(color.r, color.g, color.b);
153        let label = format!(
154            "{:7.3} {}s/i {}  ",
155            measure.as_secs_f64() * if use_micros { 1e6 } else { 1e3 },
156            if use_micros { "ยต" } else { "m" },
157            color.bold().paint(format!("{:>4.1}%", ratio * 100.0))
158        );
159        std::iter::once(label)
160    });
161
162    // cost column
163    let mut cost_column = if options.cost {
164        Some(
165            tags.cost
166                .iter()
167                .map(|c| {
168                    let key = format!("{:?}", c.0);
169                    let value = render_tdim(&c.1);
170                    let value_visible_len = c.1.to_string().len();
171                    let padding = 24usize.saturating_sub(value_visible_len + key.len());
172                    key + &*std::iter::repeat_n(' ', padding).join("") + &value.to_string() + " "
173                })
174                .peekable(),
175        )
176    } else {
177        None
178    };
179
180    // flops column
181    let mut flops_column = if options.profile && options.cost {
182        let timing: f64 = tags.profile.as_ref().map(|d| d.as_secs_f64()).unwrap_or(0.0);
183        let flops_column_pad = flops_column_pad.clone();
184        let it = tags.cost.iter().map(move |c| {
185            if c.0.is_compute() {
186                let flops = c.1.to_usize().unwrap_or(0) as f64 / timing;
187                let unpadded = si_prefix(flops, "F/s");
188                format!("{:>1$} ", unpadded, 19)
189            } else {
190                flops_column_pad.clone()
191            }
192        });
193        Some(it)
194    } else {
195        None
196    };
197
198    // tmp_mem_usage column
199    let mut tmp_mem_usage_column = if options.tmp_mem_usage {
200        let it = tags.tmp_mem_usage.iter().map(move |mem| {
201            let unpadded = if let Ok(mem_size) = mem.to_usize() {
202                pow2_prefix(mem_size, "B")
203            } else {
204                format!("{mem:.3} B")
205            };
206            format!("{:>1$} ", unpadded, mem_padding - 1)
207        });
208        Some(it)
209    } else {
210        None
211    };
212
213    // drawing column
214    let mut drawing_lines: Box<dyn Iterator<Item = String>> =
215        if let Some(ds) = drawing_state.as_mut() {
216            let body = ds.draw_node_body(model, node_id, options)?;
217            let suffix = ds.draw_node_vsuffix(model, node_id, options)?;
218            let filler = ds.draw_node_vfiller(model, node_id)?;
219            Box::new(body.into_iter().chain(suffix).chain(std::iter::repeat(filler)))
220        } else {
221            Box::new(std::iter::repeat(cost_column_pad.clone()))
222        };
223
224    macro_rules! prefix {
225        () => {
226            let cost = cost_column
227                .as_mut()
228                .map(|it| it.next().unwrap_or_else(|| cost_column_pad.to_string()))
229                .unwrap_or("".to_string());
230            let profile = profile_column
231                .as_mut()
232                .map(|it| it.next().unwrap_or_else(|| profile_column_pad.to_string()))
233                .unwrap_or("".to_string());
234            let flops = flops_column
235                .as_mut()
236                .map(|it| it.next().unwrap_or_else(|| flops_column_pad.to_string()))
237                .unwrap_or("".to_string());
238            let tmp_mem_usage = tmp_mem_usage_column
239                .as_mut()
240                .map(|it| it.next().unwrap_or_else(|| tmp_mem_usage_column_pad.to_string()))
241                .unwrap_or("".to_string());
242            print!(
243                "{}{}{}{}{}{} ",
244                profile,
245                cost,
246                flops,
247                tmp_mem_usage,
248                prefix,
249                drawing_lines.next().unwrap(),
250            )
251        };
252    }
253
254    let have_accel_profiling =
255        annotations.tags.iter().any(|(_, tag)| tag.accelerator_profile.is_some());
256    let is_cpu_fallback = have_accel_profiling
257        && tags.accelerator_profile.unwrap_or_default() == Duration::default()
258        && tags.profile.unwrap_or_default() > Duration::default();
259    let op_color = if node_name == "UnimplementedOp" {
260        Red.bold()
261    } else if is_cpu_fallback {
262        Yellow.bold()
263    } else {
264        Blue.bold()
265    };
266
267    prefix!();
268    println!(
269        "{} {} {}",
270        White.bold().paint(format!("{node_id}")),
271        op_color.paint(node_op_name),
272        name_color.italic().paint(node_name)
273    );
274    for label in tags.labels.iter() {
275        prefix!();
276        println!("  * {label}");
277    }
278    if let Io::Long = options.io {
279        for (ix, i) in model.node_inputs(node_id).iter().enumerate() {
280            let star = if ix == 0 { '*' } else { ' ' };
281            prefix!();
282            println!(
283                "  {} input fact  #{}: {} {}",
284                star,
285                ix,
286                White.bold().paint(format!("{i:?}")),
287                model.outlet_fact_format(*i),
288            );
289        }
290        for slot in 0..model.node_output_count(node_id) {
291            let star = if slot == 0 { '*' } else { ' ' };
292            let outlet = OutletId::new(node_id, slot);
293            let mut model_io = vec![];
294            for (ix, _) in model.input_outlets().iter().enumerate().filter(|(_, o)| **o == outlet) {
295                model_io.push(Cyan.bold().paint(format!("MODEL INPUT #{ix}")).to_string());
296            }
297            if let Some(t) = &tags.model_input {
298                model_io.push(t.to_string());
299            }
300            for (ix, _) in model.output_outlets().iter().enumerate().filter(|(_, o)| **o == outlet)
301            {
302                model_io.push(Yellow.bold().paint(format!("MODEL OUTPUT #{ix}")).to_string());
303            }
304            if let Some(t) = &tags.model_output {
305                model_io.push(t.to_string());
306            }
307            let successors = model.outlet_successors(outlet);
308            prefix!();
309            let mut axes =
310                tags.outlet_axes.get(slot).map(|s| s.join(",")).unwrap_or_else(|| "".to_string());
311            if !axes.is_empty() {
312                axes.push(' ')
313            }
314            println!(
315                "  {} output fact #{}: {}{} {} {} {}",
316                star,
317                slot,
318                Green.bold().italic().paint(axes),
319                model.outlet_fact_format(outlet),
320                White.bold().paint(successors.iter().map(|s| format!("{s:?}")).join(" ")),
321                model_io.join(", "),
322                Blue.bold().italic().paint(
323                    tags.outlet_labels
324                        .get(slot)
325                        .map(|s| s.join(","))
326                        .unwrap_or_else(|| "".to_string())
327                )
328            );
329            if options.outlet_labels {
330                if let Some(label) = model.outlet_label(OutletId::new(node_id, slot)) {
331                    prefix!();
332                    println!("            {} ", White.italic().paint(label));
333                }
334            }
335        }
336    }
337    if options.info {
338        for info in model.node_op(node_id).info()? {
339            prefix!();
340            println!("  * {info}");
341        }
342    }
343    if options.invariants {
344        if let Some(typed) = model.downcast_ref::<TypedModel>() {
345            let node = typed.node(node_id);
346            let (inputs, outputs) = typed.node_facts(node.id)?;
347            let axes_mapping = node.op().as_typed().unwrap().axes_mapping(&inputs, &outputs)?;
348            prefix!();
349            println!("  * {axes_mapping}");
350        }
351    }
352    if options.debug_op {
353        prefix!();
354        println!("  * {:?}", model.node_op(node_id));
355    }
356    for section in tags.sections {
357        if section.is_empty() {
358            continue;
359        }
360        prefix!();
361        println!("  * {}", section[0]);
362        for s in &section[1..] {
363            prefix!();
364            println!("    {s}");
365        }
366    }
367
368    if !options.folded {
369        for (label, sub) in model.nested_models(node_id) {
370            let prefix = drawing_lines.next().unwrap();
371            let mut scope: TVec<_> = scope.into();
372            scope.push((node_id, label));
373            let scope_prefix = scope.iter().map(|(_, p)| p).join("|");
374            render_prefixed(
375                sub,
376                &format!("{prefix} [{scope_prefix}] "),
377                &scope,
378                annotations,
379                options,
380            )?
381        }
382    }
383    if let Io::Short = options.io {
384        let same = !model.node_inputs(node_id).is_empty()
385            && model.node_output_count(node_id) == 1
386            && model.outlet_fact_format(node_id.into())
387                == model.outlet_fact_format(model.node_inputs(node_id)[0]);
388        if !same || model.output_outlets().iter().any(|o| o.node == node_id) {
389            let style = drawing_state.map(|s| s.last_wire_color()).unwrap_or_else(|| White.into());
390            for ix in 0..model.node_output_count(node_id) {
391                prefix!();
392                println!(
393                    "  {}{}{} {}",
394                    style.paint(box_drawing::heavy::HORIZONTAL),
395                    style.paint(box_drawing::heavy::HORIZONTAL),
396                    style.paint(box_drawing::heavy::HORIZONTAL),
397                    model.outlet_fact_format((node_id, ix).into())
398                );
399            }
400        }
401    }
402
403    while cost_column.as_mut().map(|cost| cost.peek().is_some()).unwrap_or(false) {
404        prefix!();
405        println!();
406    }
407
408    Ok(())
409}
410
/// Print the summary sections requested by `options`:
/// * `tmp_mem_usage`: peak flushable memory, from `annotations.memory_summary`;
/// * `cost`: the cost totals summed over every tagged node;
/// * `profile`: time per op type, time per node-name prefix, time not
///   accounted for by ops, and whole-network time.
///
/// # Panics
/// When `options.profile` is set, this unwraps `annotations.profile_summary`
/// and the `k.model(model)` lookup of every tagged node.
pub fn render_summaries(
    model: &dyn Model,
    annotations: &Annotations,
    options: &DisplayParams,
) -> TractResult<()> {
    // Sum of the tags of every annotated node (all entries of `tags`).
    let total = annotations.tags.values().sum::<NodeTags>();

    if options.tmp_mem_usage {
        if let Some(summary) = &annotations.memory_summary {
            println!("{}", White.bold().paint("Memory summary"));
            println!(" * Peak flushable memory: {}", pow2_prefix(summary.max, "B"));
        }
    }
    if options.cost {
        println!("{}", White.bold().paint("Cost summary"));
        for (c, i) in &total.cost {
            println!(" * {:?}: {}", c, render_tdim(i));
        }
    }

    if options.profile {
        let summary = annotations.profile_summary.as_ref().unwrap();

        let have_accel_profiling =
            annotations.tags.iter().any(|(_, tag)| tag.accelerator_profile.is_some());
        println!(
            "{}{}{}",
            White.bold().paint(format!("{:<43}", "Most time consuming operations")),
            White.bold().paint(format!("{:<17}", "CPU")),
            White.bold().paint(if have_accel_profiling { "Accelerator" } else { "" }),
        );

        // Group tagged nodes by op name and accumulate (CPU time, accelerator
        // time, node count) per group. Groups are then listed by decreasing
        // accelerator time when accelerator profiling exists, otherwise by
        // decreasing CPU time.
        for (op, (cpu_dur, accel_dur, n)) in annotations
            .tags
            .iter()
            .map(|(k, v)| {
                (
                    k.model(model).unwrap().node_op_name(k.1),
                    (v.profile.unwrap_or_default(), v.accelerator_profile.unwrap_or_default()),
                )
            })
            .sorted_by_key(|a| a.0.to_string())
            .chunk_by(|(n, _)| n.clone())
            .into_iter()
            .map(|(a, group)| {
                (
                    a,
                    group.into_iter().fold(
                        (Duration::default(), Duration::default(), 0),
                        |(accu, accel_accu, n), d| (accu + d.1.0, accel_accu + d.1.1, n + 1),
                    ),
                )
            })
            .sorted_by_key(|(_, d)| if have_accel_profiling { d.1 } else { d.0 })
            .rev()
        {
            // Op types that got CPU time but no accelerator time are
            // highlighted in yellow.
            let is_cpu_fallback = have_accel_profiling
                && accel_dur == Duration::default()
                && cpu_dur > Duration::default();
            let op_color = if is_cpu_fallback { Yellow.bold() } else { Blue.bold() };
            println!(
                " * {} {:3} nodes: {}  {}",
                op_color.paint(format!("{op:22}")),
                n,
                dur_avg_ratio(cpu_dur, summary.sum),
                if have_accel_profiling {
                    dur_avg_ratio(accel_dur, summary.accel_sum)
                } else {
                    "".to_string()
                }
            );
        }

        println!("{}", White.bold().paint("By prefix"));
        // Dot-separated prefixes of `s`. `take(n)` for n in 0..count yields
        // "", "a", "a.b" for "a.b.c"; the full name itself is not included.
        fn prefixes_for(s: &str) -> impl Iterator<Item = String> + '_ {
            use tract_itertools::*;
            let split = s.split('.').count();
            (0..split).map(move |n| s.split('.').take(n).join("."))
        }
        let all_prefixes = annotations
            .tags
            .keys()
            .flat_map(|id| prefixes_for(id.model(model).unwrap().node_name(id.1)))
            .filter(|s| !s.is_empty())
            .sorted()
            .unique()
            .collect::<Vec<String>>();

        for prefix in &all_prefixes {
            // NOTE(review): `starts_with` is a plain string test, so prefix
            // "conv1" also matches nodes named "conv10...". Confirm intended.
            let sum = annotations
                .tags
                .iter()
                .filter(|(k, _v)| k.model(model).unwrap().node_name(k.1).starts_with(prefix))
                .map(|(_k, v)| v)
                .sum::<NodeTags>();

            let profiler =
                if !have_accel_profiling { sum.profile } else { sum.accelerator_profile };
            // Skip prefixes under 1% of the entire network time.
            if profiler.unwrap_or_default().as_secs_f64() / summary.entire.as_secs_f64() < 0.01 {
                continue;
            }
            // NOTE(review): the percentage is relative to `summary.sum` even
            // when `profiler` holds accelerator time. Verify whether
            // `summary.accel_sum` was meant in that case.
            print!("{}    ", dur_avg_ratio(profiler.unwrap_or_default(), summary.sum));

            // Indent by nesting depth (one step per '.').
            for _ in prefix.chars().filter(|c| *c == '.') {
                print!("   ");
            }
            println!("{prefix}");
        }

        // `min` keeps the Duration subtraction from underflowing.
        println!(
            "Not accounted by ops: {}",
            dur_avg_ratio(summary.entire - summary.sum.min(summary.entire), summary.entire)
        );

        if have_accel_profiling {
            println!(
                "(Total CPU Op time - Total Accelerator Op time): {}",
                dur_avg_ratio(summary.sum - summary.accel_sum.min(summary.sum), summary.entire)
            );
        }
        println!("Entire network performance: {}", dur_avg(summary.entire));
    }

    Ok(())
}
536
/// Print a compact overview of `model`: its properties (sorted by key), its
/// input and output facts, and a per-operator table of node counts and
/// accumulated costs, followed by a total row.
///
/// Operators are ordered by decreasing total compute cost when every compute
/// cost is a concrete integer; otherwise by decreasing node count. Remaining
/// ties are broken by node count, then by operator name.
///
/// # Errors
/// Propagates errors from `model.outlet_typedfact`.
pub fn render_summary(model: &dyn Model, annotations: &Annotations) -> TractResult<()> {
    if !model.properties().is_empty() {
        println!("{}", White.bold().paint("# Properties"));
        for (k, v) in model.properties().iter().sorted_by_key(|(k, _)| k.to_string()) {
            println!("* {}: {:?}", White.paint(k), v);
        }
    }
    println!("{}", White.bold().paint("# Inputs"));
    for (ix, input) in model.input_outlets().iter().enumerate() {
        let name = model.node_name(input.node);
        let fact = model.outlet_typedfact(*input)?;
        let symbol = crate::draw::circled_input(ix);
        println!("  {symbol} {name}: {fact:?}");
    }
    println!("{}", White.bold().paint("# Outputs"));
    for (ix, output) in model.output_outlets().iter().enumerate() {
        let name = model.node_name(output.node);
        let fact = model.outlet_typedfact(*output)?;
        let symbol = crate::draw::circled_output(ix);
        println!("  {symbol} {name}: {fact:?}");
    }
    // Per op name: node count, and costs summed by cost kind. Costs are only
    // read from tags keyed at top-level scope (`NodeQId(tvec!(), id)`).
    let mut op_counts: HashMap<StaticName, usize> = HashMap::default();
    let mut op_costs: HashMap<StaticName, Vec<(Cost, TDim)>> = HashMap::default();
    for id in 0..model.nodes_len() {
        let op_name = model.node_op_name(id);
        *op_counts.entry(op_name.clone()).or_default() += 1;
        if let Some(tags) = annotations.tags.get(&NodeQId(tvec!(), id)) {
            let costs = op_costs.entry(op_name).or_default();
            for (cost_kind, value) in &tags.cost {
                if let Some(existing) = costs.iter_mut().find(|(k, _)| k == cost_kind) {
                    existing.1 = existing.1.clone() + value;
                } else {
                    costs.push((cost_kind.clone(), value.clone()));
                }
            }
        }
    }
    // Total row: compute costs only, summed over every tag entry.
    // NOTE(review): unlike the per-op rows, this also includes tags of nested
    // scopes, if any, so the rows may not add up to the total.
    let total = annotations.tags.values().sum::<NodeTags>();
    let total_cost_str = total
        .cost
        .iter()
        .filter(|(k, _)| k.is_compute())
        .map(|(kind, val)| format!("{kind:?}: {}", render_tdim(val)))
        .join(", ");
    // Sorting by cost is only meaningful when no compute cost is symbolic.
    let all_costs_concrete = op_costs
        .values()
        .all(|costs| costs.iter().filter(|(k, _)| k.is_compute()).all(|(_, v)| v.to_i64().is_ok()));
    let concrete_compute_cost = |op: &StaticName| -> i64 {
        op_costs
            .get(op)
            .map(|costs| {
                costs
                    .iter()
                    .filter(|(k, _)| k.is_compute())
                    .filter_map(|(_, v)| v.to_i64().ok())
                    .sum::<i64>()
            })
            .unwrap_or(0)
    };
    println!("{}", White.bold().paint("# Operators"));
    // NOTE(review): `concrete_compute_cost` is recomputed on every comparison;
    // `sorted_by_cached_key` could avoid this if it becomes a cost.
    for (op, count) in op_counts.iter().sorted_by(|a, b| {
        if all_costs_concrete {
            concrete_compute_cost(b.0)
                .cmp(&concrete_compute_cost(a.0))
                .then(b.1.cmp(a.1))
                .then(a.0.cmp(b.0))
        } else {
            b.1.cmp(a.1).then(a.0.cmp(b.0))
        }
    }) {
        // Per-op rows list every cost kind, not only compute costs.
        let cost_str = op_costs
            .get(op)
            .map(|costs| {
                costs.iter().map(|(kind, val)| format!("{kind:?}: {}", render_tdim(val))).join(", ")
            })
            .unwrap_or_default();
        if cost_str.is_empty() {
            println!("  {count:>5} {op}");
        } else {
            println!("  {count:>5} {op}  [{cost_str}]");
        }
    }
    let total_nodes: usize = op_counts.values().sum();
    if total_cost_str.is_empty() {
        println!("  {total_nodes:>5} total");
    } else {
        println!("  {total_nodes:>5} total  [{total_cost_str}]");
    }
    Ok(())
}
627
628/// Format a rusage::Duration showing avgtime in ms.
629pub fn dur_avg(measure: Duration) -> String {
630    White.bold().paint(format!("{:.3} ms/i", measure.as_secs_f64() * 1e3)).to_string()
631}
632
633/// Format a rusage::Duration showing avgtime in ms, with percentage to a global
634/// one.
635pub fn dur_avg_ratio(measure: Duration, global: Duration) -> String {
636    format!(
637        "{} {}",
638        White.bold().paint(format!("{:7.3} ms/i", measure.as_secs_f64() * 1e3)),
639        Yellow
640            .bold()
641            .paint(format!("{:>4.1}%", measure.as_secs_f64() / global.as_secs_f64() * 100.)),
642    )
643}
644
645fn render_tdim(d: &TDim) -> AnsiString<'static> {
646    if let Ok(i) = d.to_i64() { render_big_integer(i) } else { d.to_string().into() }
647}
648
649fn render_big_integer(i: i64) -> nu_ansi_term::AnsiString<'static> {
650    let raw = i.to_string();
651    let mut blocks = raw
652        .chars()
653        .rev()
654        .chunks(3)
655        .into_iter()
656        .map(|mut c| c.join("").chars().rev().join(""))
657        .enumerate()
658        .map(|(ix, s)| if ix % 2 == 1 { White.bold().paint(s).to_string() } else { s })
659        .collect::<Vec<_>>();
660    blocks.reverse();
661    blocks.into_iter().join("").into()
662}