tch_plus/tensor/
display.rs

/// Pretty printing of tensors.
///
/// This implementation is intended to stay in line with the PyTorch version:
/// <https://github.com/pytorch/pytorch/blob/7b419e8513a024e172eae767e24ec1b849976b13/torch/_tensor_str.py>
4use crate::{Device, Kind, Tensor};
5
/// Coarse grouping of tensor element kinds, used to pick how elements are
/// rendered when a tensor is displayed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum BasicKind {
    /// Floating-point kinds; `for_tensor` also maps the quantized kinds here.
    Float,
    /// Signed and unsigned integer kinds.
    Int,
    /// Boolean kind.
    Bool,
    /// Complex kinds; `for_tensor` also returns this when the kind lookup fails.
    Complex,
}
13
14impl BasicKind {
15    fn for_tensor(t: &Tensor) -> BasicKind {
16        match t.f_kind() {
17            Err(_) => BasicKind::Complex,
18            Ok(kind) => match kind {
19                Kind::Int | Kind::Int8 | Kind::Uint8 | Kind::Int16 | Kind::Int64 => BasicKind::Int,
20                Kind::BFloat16
21                | Kind::QInt8
22                | Kind::QUInt8
23                | Kind::QInt32
24                | Kind::Half
25                | Kind::Float
26                | Kind::Double
27                | Kind::Float8e5m2
28                | Kind::Float8e4m3fn
29                | Kind::Float8e5m2fnuz
30                | Kind::Float8e4m3fnuz => BasicKind::Float,
31                Kind::Bool => BasicKind::Bool,
32                Kind::ComplexHalf | Kind::ComplexFloat | Kind::ComplexDouble => BasicKind::Complex,
33            },
34        }
35    }
36
37    fn _is_floating_point(&self) -> bool {
38        match self {
39            BasicKind::Float => true,
40            BasicKind::Bool | BasicKind::Int | BasicKind::Complex => false,
41        }
42    }
43}
44
45impl std::fmt::Debug for Tensor {
46    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
47        if self.defined() {
48            if self.device().eq(&Device::Mps) {
49                return write!(f, "Tensor[{:?}]", self.size())
50            }
51
52            match self.f_kind() {
53                Err(err) => write!(f, "Tensor[{:?}, {:?}]", self.size(), err),
54                Ok(kind) => {
55                    let (is_int, is_float) = match kind {
56                        Kind::Int | Kind::Int8 | Kind::Uint8 | Kind::Int16 | Kind::Int64 => {
57                            (true, false)
58                        }
59                        Kind::BFloat16
60                        | Kind::QInt8
61                        | Kind::QUInt8
62                        | Kind::QInt32
63                        | Kind::Half
64                        | Kind::Float
65                        | Kind::Double
66                        | Kind::Float8e5m2
67                        | Kind::Float8e4m3fn
68                        | Kind::Float8e5m2fnuz
69                        | Kind::Float8e4m3fnuz => (false, true),
70                        Kind::Bool
71                        | Kind::ComplexHalf
72                        | Kind::ComplexFloat
73                        | Kind::ComplexDouble => (false, false),
74                    };
75                    match (self.size().as_slice(), is_int, is_float) {
76                        ([], true, false) => write!(f, "[{}]", i64::try_from(self).unwrap()),
77                        ([s], true, false) if *s < 10 => {
78                            write!(f, "{:?}", Vec::<i64>::try_from(self).unwrap())
79                        }
80                        ([], false, true) => write!(f, "[{}]", f64::try_from(self).unwrap()),
81                        ([s], false, true) if *s < 10 => {
82                            write!(f, "{:?}", Vec::<f64>::try_from(self).unwrap())
83                        }
84                        _ => write!(f, "Tensor[{:?}, {:?}]", self.size(), kind),
85                    }
86                }
87            }
88        } else {
89            write!(f, "Tensor[Undefined]")
90        }
91    }
92}
93
/// Options for Tensor pretty printing
pub struct PrinterOptions {
    /// Number of digits after the decimal point used when formatting floats.
    precision: usize,
    /// Tensors with more elements than this are summarized when displayed.
    threshold: usize,
    /// Number of leading and trailing entries kept per dimension when summarizing.
    edge_items: usize,
    /// Target line width, used to decide how many elements of a 1-D tensor go on a line.
    line_width: usize,
    /// `Some(v)` forces scientific notation on or off for floats; `None` lets
    /// the formatter decide from the values.
    sci_mode: Option<bool>,
}
102
103lazy_static! {
104    static ref PRINT_OPTS: std::sync::Mutex<PrinterOptions> =
105        std::sync::Mutex::new(Default::default());
106}
107
108pub fn set_print_options(options: PrinterOptions) {
109    *PRINT_OPTS.lock().unwrap() = options
110}
111
112pub fn set_print_options_default() {
113    *PRINT_OPTS.lock().unwrap() = Default::default()
114}
115
116pub fn set_print_options_short() {
117    *PRINT_OPTS.lock().unwrap() = PrinterOptions {
118        precision: 2,
119        threshold: 1000,
120        edge_items: 2,
121        line_width: 80,
122        sci_mode: None,
123    }
124}
125
126pub fn set_print_options_full() {
127    *PRINT_OPTS.lock().unwrap() = PrinterOptions {
128        precision: 4,
129        threshold: usize::MAX,
130        edge_items: 3,
131        line_width: 80,
132        sci_mode: None,
133    }
134}
135
136impl Default for PrinterOptions {
137    fn default() -> Self {
138        Self { precision: 4, threshold: 1000, edge_items: 3, line_width: 80, sci_mode: None }
139    }
140}
141
142trait TensorFormatter {
143    type Elem;
144
145    fn fmt<T: std::fmt::Write>(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result;
146
147    fn value(tensor: &Tensor) -> Self::Elem;
148
149    fn values(tensor: &Tensor) -> Vec<Self::Elem>;
150
151    fn max_width(&self, to_display: &Tensor) -> usize {
152        let mut max_width = 1;
153        for v in Self::values(to_display) {
154            let mut fmt_size = FmtSize::new();
155            let _res = self.fmt(v, 1, &mut fmt_size);
156            max_width = usize::max(max_width, fmt_size.final_size())
157        }
158        max_width
159    }
160
161    fn write_newline_indent(i: usize, f: &mut std::fmt::Formatter) -> std::fmt::Result {
162        writeln!(f)?;
163        for _ in 0..i {
164            write!(f, " ")?
165        }
166        Ok(())
167    }
168
169    fn fmt_tensor(
170        &self,
171        t: &Tensor,
172        indent: usize,
173        max_w: usize,
174        summarize: bool,
175        po: &PrinterOptions,
176        f: &mut std::fmt::Formatter,
177    ) -> std::fmt::Result {
178        let size = t.size();
179        let edge_items = po.edge_items as i64;
180        write!(f, "[")?;
181        match size.as_slice() {
182            [] => self.fmt(Self::value(t), max_w, f)?,
183            [v] if summarize && *v > 2 * edge_items => {
184                for v in Self::values(&t.slice(0, None, Some(edge_items), 1)).into_iter() {
185                    self.fmt(v, max_w, f)?;
186                    write!(f, ", ")?;
187                }
188                write!(f, "...")?;
189                for v in Self::values(&t.slice(0, Some(-edge_items), None, 1)).into_iter() {
190                    write!(f, ", ")?;
191                    self.fmt(v, max_w, f)?
192                }
193            }
194            [_] => {
195                let elements_per_line = usize::max(1, po.line_width / (max_w + 2));
196                for (i, v) in Self::values(t).into_iter().enumerate() {
197                    if i > 0 {
198                        if i % elements_per_line == 0 {
199                            write!(f, ",")?;
200                            Self::write_newline_indent(indent, f)?
201                        } else {
202                            write!(f, ", ")?;
203                        }
204                    }
205                    self.fmt(v, max_w, f)?
206                }
207            }
208            _ => {
209                if summarize && size[0] > 2 * edge_items {
210                    for i in 0..edge_items {
211                        self.fmt_tensor(&t.get(i), indent + 1, max_w, summarize, po, f)?;
212                        write!(f, ",")?;
213                        Self::write_newline_indent(indent, f)?
214                    }
215                    write!(f, "...")?;
216                    Self::write_newline_indent(indent, f)?;
217                    for i in size[0] - edge_items..size[0] {
218                        self.fmt_tensor(&t.get(i), indent + 1, max_w, summarize, po, f)?;
219                        if i + 1 != size[0] {
220                            write!(f, ",")?;
221                            Self::write_newline_indent(indent, f)?
222                        }
223                    }
224                } else {
225                    for i in 0..size[0] {
226                        self.fmt_tensor(&t.get(i), indent + 1, max_w, summarize, po, f)?;
227                        if i + 1 != size[0] {
228                            write!(f, ",")?;
229                            Self::write_newline_indent(indent, f)?
230                        }
231                    }
232                }
233            }
234        }
235        write!(f, "]")?;
236        Ok(())
237    }
238}
239
/// Formatter for floating-point tensors; the mode flags are chosen in
/// `FloatFormatter::new` from the values to display and the printing options.
struct FloatFormatter {
    /// Set when every nonzero finite value is a whole number; values are then
    /// written without fractional digits, followed by a `.`.
    int_mode: bool,
    /// Set when values are written in scientific notation; takes priority over `int_mode`.
    sci_mode: bool,
    /// Number of digits after the decimal point (unused in `int_mode`).
    precision: usize,
}
245
/// Write sink that discards its input and only records how wide it is.
///
/// It is used to measure how much room a formatted element needs without
/// allocating a string for it.
struct FmtSize {
    /// Number of characters written so far.
    current_size: usize,
}

impl FmtSize {
    /// Creates an empty counter.
    fn new() -> Self {
        Self { current_size: 0 }
    }

    /// Consumes the counter and returns the number of characters written.
    fn final_size(self) -> usize {
        self.current_size
    }
}

impl std::fmt::Write for FmtSize {
    fn write_str(&mut self, s: &str) -> std::fmt::Result {
        // Count characters rather than bytes: the result is used as a
        // `std::fmt` width, and padding is applied per character, so a byte
        // count would overestimate the width of any non-ASCII output.
        self.current_size += s.chars().count();
        Ok(())
    }
}
266
267impl FloatFormatter {
268    fn new(t: &Tensor, po: &PrinterOptions) -> Self {
269        let mut int_mode = true;
270        let mut sci_mode = false;
271
272        let _guard = crate::no_grad_guard();
273        let t = t.to_device(crate::Device::Cpu);
274
275        // Rather than containing all values, this should only include
276        // values that end up being displayed according to [threshold].
277        let nonzero_finite_vals = {
278            let t = t.reshape([-1]);
279            t.masked_select(&t.isfinite().logical_and(&t.ne(0.)))
280        };
281
282        let values = Vec::<f64>::try_from(&nonzero_finite_vals).unwrap();
283        if nonzero_finite_vals.numel() > 0 {
284            let nonzero_finite_abs = nonzero_finite_vals.abs();
285            let nonzero_finite_min = nonzero_finite_abs.min().double_value(&[]);
286            let nonzero_finite_max = nonzero_finite_abs.max().double_value(&[]);
287
288            for &value in values.iter() {
289                if value.ceil() != value {
290                    int_mode = false;
291                    break;
292                }
293            }
294
295            sci_mode = nonzero_finite_max / nonzero_finite_min > 1000.
296                || nonzero_finite_max > 1e8
297                || nonzero_finite_min < 1e-4
298        }
299
300        match po.sci_mode {
301            None => {}
302            Some(v) => sci_mode = v,
303        }
304        Self { int_mode, sci_mode, precision: po.precision }
305    }
306}
307
308impl TensorFormatter for FloatFormatter {
309    type Elem = f64;
310
311    fn fmt<T: std::fmt::Write>(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result {
312        if self.sci_mode {
313            write!(f, "{v:width$.prec$e}", v = v, width = max_w, prec = self.precision)
314        } else if self.int_mode {
315            if v.is_finite() {
316                write!(f, "{v:width$.0}.", v = v, width = max_w - 1)
317            } else {
318                write!(f, "{v:max_w$.0}")
319            }
320        } else {
321            write!(f, "{v:width$.prec$}", v = v, width = max_w, prec = self.precision)
322        }
323    }
324
325    fn value(tensor: &Tensor) -> Self::Elem {
326        tensor.double_value(&[])
327    }
328
329    fn values(tensor: &Tensor) -> Vec<Self::Elem> {
330        Vec::<Self::Elem>::try_from(tensor.reshape(-1)).unwrap()
331    }
332}
333
334struct IntFormatter;
335
336impl TensorFormatter for IntFormatter {
337    type Elem = i64;
338
339    fn fmt<T: std::fmt::Write>(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result {
340        write!(f, "{v:max_w$}")
341    }
342
343    fn value(tensor: &Tensor) -> Self::Elem {
344        tensor.int64_value(&[])
345    }
346
347    fn values(tensor: &Tensor) -> Vec<Self::Elem> {
348        Vec::<Self::Elem>::try_from(tensor.reshape(-1)).unwrap()
349    }
350}
351
352struct BoolFormatter;
353
354impl TensorFormatter for BoolFormatter {
355    type Elem = bool;
356
357    fn fmt<T: std::fmt::Write>(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result {
358        let v = if v { "true" } else { "false" };
359        write!(f, "{v:max_w$}")
360    }
361
362    fn value(tensor: &Tensor) -> Self::Elem {
363        tensor.int64_value(&[]) != 0
364    }
365
366    fn values(tensor: &Tensor) -> Vec<Self::Elem> {
367        Vec::<Self::Elem>::try_from(tensor.reshape(-1)).unwrap()
368    }
369}
370
371fn get_summarized_data(t: &Tensor, edge_items: i64) -> Tensor {
372    let size = t.size();
373    if size.is_empty() {
374        t.shallow_clone()
375    } else if size.len() == 1 {
376        if size[0] > 2 * edge_items {
377            Tensor::cat(
378                &[t.slice(0, None, Some(edge_items), 1), t.slice(0, Some(-edge_items), None, 1)],
379                0,
380            )
381        } else {
382            t.shallow_clone()
383        }
384    } else if size[0] > 2 * edge_items {
385        let mut vs: Vec<_> =
386            (0..edge_items).map(|i| get_summarized_data(&t.get(i), edge_items)).collect();
387        for i in (size[0] - edge_items)..size[0] {
388            vs.push(get_summarized_data(&t.get(i), edge_items))
389        }
390        Tensor::stack(&vs, 0)
391    } else {
392        let vs: Vec<_> = (0..size[0]).map(|i| get_summarized_data(&t.get(i), edge_items)).collect();
393        Tensor::stack(&vs, 0)
394    }
395}
396
397impl std::fmt::Display for Tensor {
398    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
399        if self.defined() {
400            let po = PRINT_OPTS.lock().unwrap();
401            let summarize = self.numel() > po.threshold;
402            let basic_kind = BasicKind::for_tensor(self);
403            let to_display = if summarize {
404                get_summarized_data(self, po.edge_items as i64)
405            } else {
406                self.shallow_clone()
407            };
408            match basic_kind {
409                BasicKind::Int => {
410                    let tf = IntFormatter;
411                    let max_w = tf.max_width(&to_display);
412                    tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
413                    writeln!(f)?;
414                }
415                BasicKind::Float => {
416                    let tf = FloatFormatter::new(&to_display, &po);
417                    let max_w = tf.max_width(&to_display);
418                    tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
419                    writeln!(f)?;
420                }
421                BasicKind::Bool => {
422                    let tf = BoolFormatter;
423                    let max_w = tf.max_width(&to_display);
424                    tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
425                    writeln!(f)?;
426                }
427                BasicKind::Complex => {}
428            };
429            let kind = match self.f_kind() {
430                Ok(kind) => format!("{kind:?}"),
431                Err(err) => format!("{err:?}"),
432            };
433            write!(f, "Tensor[{:?}, {}]", self.size(), kind)
434        } else {
435            write!(f, "Tensor[Undefined]")
436        }
437    }
438}