1use crate::{Kind, Tensor};
5
/// Coarse element category of a tensor, used to pick the value formatter.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum BasicKind {
    /// Floating-point dtypes, including the quantized kinds.
    Float,
    /// Signed and unsigned integer dtypes.
    Int,
    /// The boolean dtype.
    Bool,
    /// Complex dtypes, and tensors whose kind could not be determined.
    Complex,
}
13
14impl BasicKind {
15 fn for_tensor(t: &Tensor) -> BasicKind {
16 match t.f_kind() {
17 Err(_) => BasicKind::Complex,
18 Ok(kind) => match kind {
19 Kind::Int | Kind::Int8 | Kind::Uint8 | Kind::Int16 | Kind::Int64 => BasicKind::Int,
20 Kind::BFloat16
21 | Kind::QInt8
22 | Kind::QUInt8
23 | Kind::QInt32
24 | Kind::Half
25 | Kind::Float
26 | Kind::Double
27 | Kind::Float8e5m2
28 | Kind::Float8e4m3fn
29 | Kind::Float8e5m2fnuz
30 | Kind::Float8e4m3fnuz => BasicKind::Float,
31 Kind::Bool => BasicKind::Bool,
32 Kind::ComplexHalf | Kind::ComplexFloat | Kind::ComplexDouble => BasicKind::Complex,
33 },
34 }
35 }
36
37 fn _is_floating_point(&self) -> bool {
38 match self {
39 BasicKind::Float => true,
40 BasicKind::Bool | BasicKind::Int | BasicKind::Complex => false,
41 }
42 }
43}
44
45impl std::fmt::Debug for Tensor {
46 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
47 if self.defined() {
48 match self.f_kind() {
49 Err(err) => write!(f, "Tensor[{:?}, {:?}]", self.size(), err),
50 Ok(kind) => {
51 let (is_int, is_float) = match kind {
52 Kind::Int | Kind::Int8 | Kind::Uint8 | Kind::Int16 | Kind::Int64 => {
53 (true, false)
54 }
55 Kind::BFloat16
56 | Kind::QInt8
57 | Kind::QUInt8
58 | Kind::QInt32
59 | Kind::Half
60 | Kind::Float
61 | Kind::Double
62 | Kind::Float8e5m2
63 | Kind::Float8e4m3fn
64 | Kind::Float8e5m2fnuz
65 | Kind::Float8e4m3fnuz => (false, true),
66 Kind::Bool
67 | Kind::ComplexHalf
68 | Kind::ComplexFloat
69 | Kind::ComplexDouble => (false, false),
70 };
71 match (self.size().as_slice(), is_int, is_float) {
72 ([], true, false) => write!(f, "[{}]", i64::try_from(self).unwrap()),
73 ([s], true, false) if *s < 10 => {
74 write!(f, "{:?}", Vec::<i64>::try_from(self).unwrap())
75 }
76 ([], false, true) => write!(f, "[{}]", f64::try_from(self).unwrap()),
77 ([s], false, true) if *s < 10 => {
78 write!(f, "{:?}", Vec::<f64>::try_from(self).unwrap())
79 }
80 _ => write!(f, "Tensor[{:?}, {:?}]", self.size(), kind),
81 }
82 }
83 }
84 } else {
85 write!(f, "Tensor[Undefined]")
86 }
87 }
88}
89
/// Options controlling how tensors are rendered by their `Display` implementation.
///
/// The active options are process-wide; install them with `set_print_options` or one of the
/// `set_print_options_*` presets.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PrinterOptions {
    /// Digits after the decimal point for floating-point values (fixed and scientific).
    precision: usize,
    /// Tensors with more elements than this are printed in summarized form.
    threshold: usize,
    /// Number of leading and trailing items shown per dimension when summarizing.
    edge_items: usize,
    /// Target line width, in characters, used to wrap the elements of 1-d rows.
    line_width: usize,
    /// `Some(true)` forces and `Some(false)` disables scientific notation; `None` lets the
    /// formatter decide from the value range.
    sci_mode: Option<bool>,
}
98
99lazy_static! {
100 static ref PRINT_OPTS: std::sync::Mutex<PrinterOptions> =
101 std::sync::Mutex::new(Default::default());
102}
103
104pub fn set_print_options(options: PrinterOptions) {
105 *PRINT_OPTS.lock().unwrap() = options
106}
107
108pub fn set_print_options_default() {
109 *PRINT_OPTS.lock().unwrap() = Default::default()
110}
111
112pub fn set_print_options_short() {
113 *PRINT_OPTS.lock().unwrap() = PrinterOptions {
114 precision: 2,
115 threshold: 1000,
116 edge_items: 2,
117 line_width: 80,
118 sci_mode: None,
119 }
120}
121
122pub fn set_print_options_full() {
123 *PRINT_OPTS.lock().unwrap() = PrinterOptions {
124 precision: 4,
125 threshold: usize::MAX,
126 edge_items: 3,
127 line_width: 80,
128 sci_mode: None,
129 }
130}
131
132impl Default for PrinterOptions {
133 fn default() -> Self {
134 Self { precision: 4, threshold: 1000, edge_items: 3, line_width: 80, sci_mode: None }
135 }
136}
137
138trait TensorFormatter {
139 type Elem;
140
141 fn fmt<T: std::fmt::Write>(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result;
142
143 fn value(tensor: &Tensor) -> Self::Elem;
144
145 fn values(tensor: &Tensor) -> Vec<Self::Elem>;
146
147 fn max_width(&self, to_display: &Tensor) -> usize {
148 let mut max_width = 1;
149 for v in Self::values(to_display) {
150 let mut fmt_size = FmtSize::new();
151 let _res = self.fmt(v, 1, &mut fmt_size);
152 max_width = usize::max(max_width, fmt_size.final_size())
153 }
154 max_width
155 }
156
157 fn write_newline_indent(i: usize, f: &mut std::fmt::Formatter) -> std::fmt::Result {
158 writeln!(f)?;
159 for _ in 0..i {
160 write!(f, " ")?
161 }
162 Ok(())
163 }
164
165 fn fmt_tensor(
166 &self,
167 t: &Tensor,
168 indent: usize,
169 max_w: usize,
170 summarize: bool,
171 po: &PrinterOptions,
172 f: &mut std::fmt::Formatter,
173 ) -> std::fmt::Result {
174 let size = t.size();
175 let edge_items = po.edge_items as i64;
176 write!(f, "[")?;
177 match size.as_slice() {
178 [] => self.fmt(Self::value(t), max_w, f)?,
179 [v] if summarize && *v > 2 * edge_items => {
180 for v in Self::values(&t.slice(0, None, Some(edge_items), 1)).into_iter() {
181 self.fmt(v, max_w, f)?;
182 write!(f, ", ")?;
183 }
184 write!(f, "...")?;
185 for v in Self::values(&t.slice(0, Some(-edge_items), None, 1)).into_iter() {
186 write!(f, ", ")?;
187 self.fmt(v, max_w, f)?
188 }
189 }
190 [_] => {
191 let elements_per_line = usize::max(1, po.line_width / (max_w + 2));
192 for (i, v) in Self::values(t).into_iter().enumerate() {
193 if i > 0 {
194 if i % elements_per_line == 0 {
195 write!(f, ",")?;
196 Self::write_newline_indent(indent, f)?
197 } else {
198 write!(f, ", ")?;
199 }
200 }
201 self.fmt(v, max_w, f)?
202 }
203 }
204 _ => {
205 if summarize && size[0] > 2 * edge_items {
206 for i in 0..edge_items {
207 self.fmt_tensor(&t.get(i), indent + 1, max_w, summarize, po, f)?;
208 write!(f, ",")?;
209 Self::write_newline_indent(indent, f)?
210 }
211 write!(f, "...")?;
212 Self::write_newline_indent(indent, f)?;
213 for i in size[0] - edge_items..size[0] {
214 self.fmt_tensor(&t.get(i), indent + 1, max_w, summarize, po, f)?;
215 if i + 1 != size[0] {
216 write!(f, ",")?;
217 Self::write_newline_indent(indent, f)?
218 }
219 }
220 } else {
221 for i in 0..size[0] {
222 self.fmt_tensor(&t.get(i), indent + 1, max_w, summarize, po, f)?;
223 if i + 1 != size[0] {
224 write!(f, ",")?;
225 Self::write_newline_indent(indent, f)?
226 }
227 }
228 }
229 }
230 }
231 write!(f, "]")?;
232 Ok(())
233 }
234}
235
/// Renders floating-point elements in a display mode chosen once per tensor.
struct FloatFormatter {
    /// Print values as integers followed by `.` (e.g. `3.`).
    int_mode: bool,
    /// Print values in scientific notation.
    sci_mode: bool,
    /// Digits after the decimal point in fixed-point and scientific modes.
    precision: usize,
}
241
/// A `std::fmt::Write` sink that discards its output and only measures its width.
///
/// Width is counted in `char`s, which is how `std::fmt` measures width and padding, so the
/// result can be fed back as a `{:width$}` argument.
struct FmtSize {
    current_size: usize,
}

impl FmtSize {
    /// Creates a sink that has measured nothing yet.
    fn new() -> Self {
        Self { current_size: 0 }
    }

    /// Consumes the sink and returns the number of characters written to it.
    fn final_size(self) -> usize {
        self.current_size
    }
}

impl std::fmt::Write for FmtSize {
    fn write_str(&mut self, s: &str) -> std::fmt::Result {
        // Count characters rather than bytes: std::fmt pads by character count.
        self.current_size += s.chars().count();
        Ok(())
    }
}
262
263impl FloatFormatter {
264 fn new(t: &Tensor, po: &PrinterOptions) -> Self {
265 let mut int_mode = true;
266 let mut sci_mode = false;
267
268 let _guard = crate::no_grad_guard();
269 let t = t.to_device(crate::Device::Cpu);
270
271 let nonzero_finite_vals = {
274 let t = t.reshape([-1]);
275 t.masked_select(&t.isfinite().logical_and(&t.ne(0.)))
276 };
277
278 let values = Vec::<f64>::try_from(&nonzero_finite_vals).unwrap();
279 if nonzero_finite_vals.numel() > 0 {
280 let nonzero_finite_abs = nonzero_finite_vals.abs();
281 let nonzero_finite_min = nonzero_finite_abs.min().double_value(&[]);
282 let nonzero_finite_max = nonzero_finite_abs.max().double_value(&[]);
283
284 for &value in values.iter() {
285 if value.ceil() != value {
286 int_mode = false;
287 break;
288 }
289 }
290
291 sci_mode = nonzero_finite_max / nonzero_finite_min > 1000.
292 || nonzero_finite_max > 1e8
293 || nonzero_finite_min < 1e-4
294 }
295
296 match po.sci_mode {
297 None => {}
298 Some(v) => sci_mode = v,
299 }
300 Self { int_mode, sci_mode, precision: po.precision }
301 }
302}
303
304impl TensorFormatter for FloatFormatter {
305 type Elem = f64;
306
307 fn fmt<T: std::fmt::Write>(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result {
308 if self.sci_mode {
309 write!(f, "{v:width$.prec$e}", v = v, width = max_w, prec = self.precision)
310 } else if self.int_mode {
311 if v.is_finite() {
312 write!(f, "{v:width$.0}.", v = v, width = max_w - 1)
313 } else {
314 write!(f, "{v:max_w$.0}")
315 }
316 } else {
317 write!(f, "{v:width$.prec$}", v = v, width = max_w, prec = self.precision)
318 }
319 }
320
321 fn value(tensor: &Tensor) -> Self::Elem {
322 tensor.double_value(&[])
323 }
324
325 fn values(tensor: &Tensor) -> Vec<Self::Elem> {
326 Vec::<Self::Elem>::try_from(tensor.reshape(-1)).unwrap()
327 }
328}
329
330struct IntFormatter;
331
332impl TensorFormatter for IntFormatter {
333 type Elem = i64;
334
335 fn fmt<T: std::fmt::Write>(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result {
336 write!(f, "{v:max_w$}")
337 }
338
339 fn value(tensor: &Tensor) -> Self::Elem {
340 tensor.int64_value(&[])
341 }
342
343 fn values(tensor: &Tensor) -> Vec<Self::Elem> {
344 Vec::<Self::Elem>::try_from(tensor.reshape(-1)).unwrap()
345 }
346}
347
348struct BoolFormatter;
349
350impl TensorFormatter for BoolFormatter {
351 type Elem = bool;
352
353 fn fmt<T: std::fmt::Write>(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result {
354 let v = if v { "true" } else { "false" };
355 write!(f, "{v:max_w$}")
356 }
357
358 fn value(tensor: &Tensor) -> Self::Elem {
359 tensor.int64_value(&[]) != 0
360 }
361
362 fn values(tensor: &Tensor) -> Vec<Self::Elem> {
363 Vec::<Self::Elem>::try_from(tensor.reshape(-1)).unwrap()
364 }
365}
366
367fn get_summarized_data(t: &Tensor, edge_items: i64) -> Tensor {
368 let size = t.size();
369 if size.is_empty() {
370 t.shallow_clone()
371 } else if size.len() == 1 {
372 if size[0] > 2 * edge_items {
373 Tensor::cat(
374 &[t.slice(0, None, Some(edge_items), 1), t.slice(0, Some(-edge_items), None, 1)],
375 0,
376 )
377 } else {
378 t.shallow_clone()
379 }
380 } else if size[0] > 2 * edge_items {
381 let mut vs: Vec<_> =
382 (0..edge_items).map(|i| get_summarized_data(&t.get(i), edge_items)).collect();
383 for i in (size[0] - edge_items)..size[0] {
384 vs.push(get_summarized_data(&t.get(i), edge_items))
385 }
386 Tensor::stack(&vs, 0)
387 } else {
388 let vs: Vec<_> = (0..size[0]).map(|i| get_summarized_data(&t.get(i), edge_items)).collect();
389 Tensor::stack(&vs, 0)
390 }
391}
392
393impl std::fmt::Display for Tensor {
394 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
395 if self.defined() {
396 let po = PRINT_OPTS.lock().unwrap();
397 let summarize = self.numel() > po.threshold;
398 let basic_kind = BasicKind::for_tensor(self);
399 let to_display = if summarize {
400 get_summarized_data(self, po.edge_items as i64)
401 } else {
402 self.shallow_clone()
403 };
404 match basic_kind {
405 BasicKind::Int => {
406 let tf = IntFormatter;
407 let max_w = tf.max_width(&to_display);
408 tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
409 writeln!(f)?;
410 }
411 BasicKind::Float => {
412 let tf = FloatFormatter::new(&to_display, &po);
413 let max_w = tf.max_width(&to_display);
414 tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
415 writeln!(f)?;
416 }
417 BasicKind::Bool => {
418 let tf = BoolFormatter;
419 let max_w = tf.max_width(&to_display);
420 tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
421 writeln!(f)?;
422 }
423 BasicKind::Complex => {}
424 };
425 let kind = match self.f_kind() {
426 Ok(kind) => format!("{kind:?}"),
427 Err(err) => format!("{err:?}"),
428 };
429 write!(f, "Tensor[{:?}, {}]", self.size(), kind)
430 } else {
431 write!(f, "Tensor[Undefined]")
432 }
433 }
434}