tch_plus/tensor/
display.rs

//! Pretty printing of tensors.
//!
//! This implementation aims to stay in line with the PyTorch version:
//! <https://github.com/pytorch/pytorch/blob/7b419e8513a024e172eae767e24ec1b849976b13/torch/_tensor_str.py>
4use crate::{Kind, Tensor};
5
/// Coarse classification of a tensor's element kind, used to choose how the
/// elements are rendered.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum BasicKind {
    /// Elements are rendered as floating-point numbers.
    Float,
    /// Elements are rendered as integers.
    Int,
    /// Elements are rendered as `true` / `false`.
    Bool,
    /// Complex kinds (also used when the kind cannot be determined).
    Complex,
}
13
14impl BasicKind {
15    fn for_tensor(t: &Tensor) -> BasicKind {
16        match t.f_kind() {
17            Err(_) => BasicKind::Complex,
18            Ok(kind) => match kind {
19                Kind::Int | Kind::Int8 | Kind::Uint8 | Kind::Int16 | Kind::Int64 => BasicKind::Int,
20                Kind::BFloat16
21                | Kind::QInt8
22                | Kind::QUInt8
23                | Kind::QInt32
24                | Kind::Half
25                | Kind::Float
26                | Kind::Double
27                | Kind::Float8e5m2
28                | Kind::Float8e4m3fn
29                | Kind::Float8e5m2fnuz
30                | Kind::Float8e4m3fnuz => BasicKind::Float,
31                Kind::Bool => BasicKind::Bool,
32                Kind::ComplexHalf | Kind::ComplexFloat | Kind::ComplexDouble => BasicKind::Complex,
33            },
34        }
35    }
36
37    fn _is_floating_point(&self) -> bool {
38        match self {
39            BasicKind::Float => true,
40            BasicKind::Bool | BasicKind::Int | BasicKind::Complex => false,
41        }
42    }
43}
44
45impl std::fmt::Debug for Tensor {
46    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
47        if self.defined() {
48            match self.f_kind() {
49                Err(err) => write!(f, "Tensor[{:?}, {:?}]", self.size(), err),
50                Ok(kind) => {
51                    let (is_int, is_float) = match kind {
52                        Kind::Int | Kind::Int8 | Kind::Uint8 | Kind::Int16 | Kind::Int64 => {
53                            (true, false)
54                        }
55                        Kind::BFloat16
56                        | Kind::QInt8
57                        | Kind::QUInt8
58                        | Kind::QInt32
59                        | Kind::Half
60                        | Kind::Float
61                        | Kind::Double
62                        | Kind::Float8e5m2
63                        | Kind::Float8e4m3fn
64                        | Kind::Float8e5m2fnuz
65                        | Kind::Float8e4m3fnuz => (false, true),
66                        Kind::Bool
67                        | Kind::ComplexHalf
68                        | Kind::ComplexFloat
69                        | Kind::ComplexDouble => (false, false),
70                    };
71                    match (self.size().as_slice(), is_int, is_float) {
72                        ([], true, false) => write!(f, "[{}]", i64::try_from(self).unwrap()),
73                        ([s], true, false) if *s < 10 => {
74                            write!(f, "{:?}", Vec::<i64>::try_from(self).unwrap())
75                        }
76                        ([], false, true) => write!(f, "[{}]", f64::try_from(self).unwrap()),
77                        ([s], false, true) if *s < 10 => {
78                            write!(f, "{:?}", Vec::<f64>::try_from(self).unwrap())
79                        }
80                        _ => write!(f, "Tensor[{:?}, {:?}]", self.size(), kind),
81                    }
82                }
83            }
84        } else {
85            write!(f, "Tensor[Undefined]")
86        }
87    }
88}
89
/// Options for Tensor pretty printing.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PrinterOptions {
    /// Number of digits after the decimal point for floating-point values.
    precision: usize,
    /// Tensors with more elements than this are printed in summarized form.
    threshold: usize,
    /// Number of leading and trailing entries kept per dimension when
    /// summarizing.
    edge_items: usize,
    /// Target line width used to wrap the elements of 1-D tensors.
    line_width: usize,
    /// `Some(_)` forces scientific notation on or off; `None` lets the float
    /// formatter decide from the values.
    sci_mode: Option<bool>,
}
98
lazy_static! {
    // Process-wide printing options: read by the `Display` implementation
    // for `Tensor` and replaced as a whole by the `set_print_options*`
    // functions. Starts out as `PrinterOptions::default()`.
    static ref PRINT_OPTS: std::sync::Mutex<PrinterOptions> =
        std::sync::Mutex::new(Default::default());
}
103
104pub fn set_print_options(options: PrinterOptions) {
105    *PRINT_OPTS.lock().unwrap() = options
106}
107
108pub fn set_print_options_default() {
109    *PRINT_OPTS.lock().unwrap() = Default::default()
110}
111
112pub fn set_print_options_short() {
113    *PRINT_OPTS.lock().unwrap() = PrinterOptions {
114        precision: 2,
115        threshold: 1000,
116        edge_items: 2,
117        line_width: 80,
118        sci_mode: None,
119    }
120}
121
122pub fn set_print_options_full() {
123    *PRINT_OPTS.lock().unwrap() = PrinterOptions {
124        precision: 4,
125        threshold: usize::MAX,
126        edge_items: 3,
127        line_width: 80,
128        sci_mode: None,
129    }
130}
131
132impl Default for PrinterOptions {
133    fn default() -> Self {
134        Self { precision: 4, threshold: 1000, edge_items: 3, line_width: 80, sci_mode: None }
135    }
136}
137
138trait TensorFormatter {
139    type Elem;
140
141    fn fmt<T: std::fmt::Write>(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result;
142
143    fn value(tensor: &Tensor) -> Self::Elem;
144
145    fn values(tensor: &Tensor) -> Vec<Self::Elem>;
146
147    fn max_width(&self, to_display: &Tensor) -> usize {
148        let mut max_width = 1;
149        for v in Self::values(to_display) {
150            let mut fmt_size = FmtSize::new();
151            let _res = self.fmt(v, 1, &mut fmt_size);
152            max_width = usize::max(max_width, fmt_size.final_size())
153        }
154        max_width
155    }
156
157    fn write_newline_indent(i: usize, f: &mut std::fmt::Formatter) -> std::fmt::Result {
158        writeln!(f)?;
159        for _ in 0..i {
160            write!(f, " ")?
161        }
162        Ok(())
163    }
164
165    fn fmt_tensor(
166        &self,
167        t: &Tensor,
168        indent: usize,
169        max_w: usize,
170        summarize: bool,
171        po: &PrinterOptions,
172        f: &mut std::fmt::Formatter,
173    ) -> std::fmt::Result {
174        let size = t.size();
175        let edge_items = po.edge_items as i64;
176        write!(f, "[")?;
177        match size.as_slice() {
178            [] => self.fmt(Self::value(t), max_w, f)?,
179            [v] if summarize && *v > 2 * edge_items => {
180                for v in Self::values(&t.slice(0, None, Some(edge_items), 1)).into_iter() {
181                    self.fmt(v, max_w, f)?;
182                    write!(f, ", ")?;
183                }
184                write!(f, "...")?;
185                for v in Self::values(&t.slice(0, Some(-edge_items), None, 1)).into_iter() {
186                    write!(f, ", ")?;
187                    self.fmt(v, max_w, f)?
188                }
189            }
190            [_] => {
191                let elements_per_line = usize::max(1, po.line_width / (max_w + 2));
192                for (i, v) in Self::values(t).into_iter().enumerate() {
193                    if i > 0 {
194                        if i % elements_per_line == 0 {
195                            write!(f, ",")?;
196                            Self::write_newline_indent(indent, f)?
197                        } else {
198                            write!(f, ", ")?;
199                        }
200                    }
201                    self.fmt(v, max_w, f)?
202                }
203            }
204            _ => {
205                if summarize && size[0] > 2 * edge_items {
206                    for i in 0..edge_items {
207                        self.fmt_tensor(&t.get(i), indent + 1, max_w, summarize, po, f)?;
208                        write!(f, ",")?;
209                        Self::write_newline_indent(indent, f)?
210                    }
211                    write!(f, "...")?;
212                    Self::write_newline_indent(indent, f)?;
213                    for i in size[0] - edge_items..size[0] {
214                        self.fmt_tensor(&t.get(i), indent + 1, max_w, summarize, po, f)?;
215                        if i + 1 != size[0] {
216                            write!(f, ",")?;
217                            Self::write_newline_indent(indent, f)?
218                        }
219                    }
220                } else {
221                    for i in 0..size[0] {
222                        self.fmt_tensor(&t.get(i), indent + 1, max_w, summarize, po, f)?;
223                        if i + 1 != size[0] {
224                            write!(f, ",")?;
225                            Self::write_newline_indent(indent, f)?
226                        }
227                    }
228                }
229            }
230        }
231        write!(f, "]")?;
232        Ok(())
233    }
234}
235
/// Formatter for floating-point tensors.
struct FloatFormatter {
    /// All displayed finite non-zero values are whole numbers; they are
    /// printed without decimals, followed by a trailing `.`.
    int_mode: bool,
    /// Print values in scientific (exponent) notation.
    sci_mode: bool,
    /// Number of digits after the decimal point outside of `int_mode`.
    precision: usize,
}
241
/// A `std::fmt::Write` sink that discards the text and only records how many
/// characters were written. Used to measure the rendered width of a value.
struct FmtSize {
    current_size: usize,
}

impl FmtSize {
    /// Creates an empty counter.
    fn new() -> Self {
        Self { current_size: 0 }
    }

    /// Consumes the counter and returns the number of characters written.
    fn final_size(self) -> usize {
        self.current_size
    }
}

impl std::fmt::Write for FmtSize {
    fn write_str(&mut self, s: &str) -> std::fmt::Result {
        // Count characters rather than bytes: the measured value is later
        // used as a `{:width$}` argument, and format widths are expressed in
        // characters. For ASCII output the two are identical.
        self.current_size += s.chars().count();
        Ok(())
    }
}
262
263impl FloatFormatter {
264    fn new(t: &Tensor, po: &PrinterOptions) -> Self {
265        let mut int_mode = true;
266        let mut sci_mode = false;
267
268        let _guard = crate::no_grad_guard();
269        let t = t.to_device(crate::Device::Cpu);
270
271        // Rather than containing all values, this should only include
272        // values that end up being displayed according to [threshold].
273        let nonzero_finite_vals = {
274            let t = t.reshape([-1]);
275            t.masked_select(&t.isfinite().logical_and(&t.ne(0.)))
276        };
277
278        let values = Vec::<f64>::try_from(&nonzero_finite_vals).unwrap();
279        if nonzero_finite_vals.numel() > 0 {
280            let nonzero_finite_abs = nonzero_finite_vals.abs();
281            let nonzero_finite_min = nonzero_finite_abs.min().double_value(&[]);
282            let nonzero_finite_max = nonzero_finite_abs.max().double_value(&[]);
283
284            for &value in values.iter() {
285                if value.ceil() != value {
286                    int_mode = false;
287                    break;
288                }
289            }
290
291            sci_mode = nonzero_finite_max / nonzero_finite_min > 1000.
292                || nonzero_finite_max > 1e8
293                || nonzero_finite_min < 1e-4
294        }
295
296        match po.sci_mode {
297            None => {}
298            Some(v) => sci_mode = v,
299        }
300        Self { int_mode, sci_mode, precision: po.precision }
301    }
302}
303
304impl TensorFormatter for FloatFormatter {
305    type Elem = f64;
306
307    fn fmt<T: std::fmt::Write>(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result {
308        if self.sci_mode {
309            write!(f, "{v:width$.prec$e}", v = v, width = max_w, prec = self.precision)
310        } else if self.int_mode {
311            if v.is_finite() {
312                write!(f, "{v:width$.0}.", v = v, width = max_w - 1)
313            } else {
314                write!(f, "{v:max_w$.0}")
315            }
316        } else {
317            write!(f, "{v:width$.prec$}", v = v, width = max_w, prec = self.precision)
318        }
319    }
320
321    fn value(tensor: &Tensor) -> Self::Elem {
322        tensor.double_value(&[])
323    }
324
325    fn values(tensor: &Tensor) -> Vec<Self::Elem> {
326        Vec::<Self::Elem>::try_from(tensor.reshape(-1)).unwrap()
327    }
328}
329
330struct IntFormatter;
331
332impl TensorFormatter for IntFormatter {
333    type Elem = i64;
334
335    fn fmt<T: std::fmt::Write>(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result {
336        write!(f, "{v:max_w$}")
337    }
338
339    fn value(tensor: &Tensor) -> Self::Elem {
340        tensor.int64_value(&[])
341    }
342
343    fn values(tensor: &Tensor) -> Vec<Self::Elem> {
344        Vec::<Self::Elem>::try_from(tensor.reshape(-1)).unwrap()
345    }
346}
347
348struct BoolFormatter;
349
350impl TensorFormatter for BoolFormatter {
351    type Elem = bool;
352
353    fn fmt<T: std::fmt::Write>(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result {
354        let v = if v { "true" } else { "false" };
355        write!(f, "{v:max_w$}")
356    }
357
358    fn value(tensor: &Tensor) -> Self::Elem {
359        tensor.int64_value(&[]) != 0
360    }
361
362    fn values(tensor: &Tensor) -> Vec<Self::Elem> {
363        Vec::<Self::Elem>::try_from(tensor.reshape(-1)).unwrap()
364    }
365}
366
367fn get_summarized_data(t: &Tensor, edge_items: i64) -> Tensor {
368    let size = t.size();
369    if size.is_empty() {
370        t.shallow_clone()
371    } else if size.len() == 1 {
372        if size[0] > 2 * edge_items {
373            Tensor::cat(
374                &[t.slice(0, None, Some(edge_items), 1), t.slice(0, Some(-edge_items), None, 1)],
375                0,
376            )
377        } else {
378            t.shallow_clone()
379        }
380    } else if size[0] > 2 * edge_items {
381        let mut vs: Vec<_> =
382            (0..edge_items).map(|i| get_summarized_data(&t.get(i), edge_items)).collect();
383        for i in (size[0] - edge_items)..size[0] {
384            vs.push(get_summarized_data(&t.get(i), edge_items))
385        }
386        Tensor::stack(&vs, 0)
387    } else {
388        let vs: Vec<_> = (0..size[0]).map(|i| get_summarized_data(&t.get(i), edge_items)).collect();
389        Tensor::stack(&vs, 0)
390    }
391}
392
393impl std::fmt::Display for Tensor {
394    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
395        if self.defined() {
396            let po = PRINT_OPTS.lock().unwrap();
397            let summarize = self.numel() > po.threshold;
398            let basic_kind = BasicKind::for_tensor(self);
399            let to_display = if summarize {
400                get_summarized_data(self, po.edge_items as i64)
401            } else {
402                self.shallow_clone()
403            };
404            match basic_kind {
405                BasicKind::Int => {
406                    let tf = IntFormatter;
407                    let max_w = tf.max_width(&to_display);
408                    tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
409                    writeln!(f)?;
410                }
411                BasicKind::Float => {
412                    let tf = FloatFormatter::new(&to_display, &po);
413                    let max_w = tf.max_width(&to_display);
414                    tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
415                    writeln!(f)?;
416                }
417                BasicKind::Bool => {
418                    let tf = BoolFormatter;
419                    let max_w = tf.max_width(&to_display);
420                    tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
421                    writeln!(f)?;
422                }
423                BasicKind::Complex => {}
424            };
425            let kind = match self.f_kind() {
426                Ok(kind) => format!("{kind:?}"),
427                Err(err) => format!("{err:?}"),
428            };
429            write!(f, "Tensor[{:?}, {}]", self.size(), kind)
430        } else {
431            write!(f, "Tensor[Undefined]")
432        }
433    }
434}