1use crate::{Device, Kind, Tensor};
5
6#[derive(Debug, Clone, Copy, PartialEq, Eq)]
7enum BasicKind {
8 Float,
9 Int,
10 Bool,
11 Complex,
12}
13
14impl BasicKind {
15 fn for_tensor(t: &Tensor) -> BasicKind {
16 match t.f_kind() {
17 Err(_) => BasicKind::Complex,
18 Ok(kind) => match kind {
19 Kind::Int | Kind::Int8 | Kind::Uint8 | Kind::Int16 | Kind::Int64 => BasicKind::Int,
20 Kind::BFloat16
21 | Kind::QInt8
22 | Kind::QUInt8
23 | Kind::QInt32
24 | Kind::Half
25 | Kind::Float
26 | Kind::Double
27 | Kind::Float8e5m2
28 | Kind::Float8e4m3fn
29 | Kind::Float8e5m2fnuz
30 | Kind::Float8e4m3fnuz => BasicKind::Float,
31 Kind::Bool => BasicKind::Bool,
32 Kind::ComplexHalf | Kind::ComplexFloat | Kind::ComplexDouble => BasicKind::Complex,
33 },
34 }
35 }
36
37 fn _is_floating_point(&self) -> bool {
38 match self {
39 BasicKind::Float => true,
40 BasicKind::Bool | BasicKind::Int | BasicKind::Complex => false,
41 }
42 }
43}
44
45impl std::fmt::Debug for Tensor {
46 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
47 if self.defined() {
48 if self.device().eq(&Device::Mps) {
49 return write!(f, "Tensor[{:?}]", self.size())
50 }
51
52 match self.f_kind() {
53 Err(err) => write!(f, "Tensor[{:?}, {:?}]", self.size(), err),
54 Ok(kind) => {
55 let (is_int, is_float) = match kind {
56 Kind::Int | Kind::Int8 | Kind::Uint8 | Kind::Int16 | Kind::Int64 => {
57 (true, false)
58 }
59 Kind::BFloat16
60 | Kind::QInt8
61 | Kind::QUInt8
62 | Kind::QInt32
63 | Kind::Half
64 | Kind::Float
65 | Kind::Double
66 | Kind::Float8e5m2
67 | Kind::Float8e4m3fn
68 | Kind::Float8e5m2fnuz
69 | Kind::Float8e4m3fnuz => (false, true),
70 Kind::Bool
71 | Kind::ComplexHalf
72 | Kind::ComplexFloat
73 | Kind::ComplexDouble => (false, false),
74 };
75 match (self.size().as_slice(), is_int, is_float) {
76 ([], true, false) => write!(f, "[{}]", i64::try_from(self).unwrap()),
77 ([s], true, false) if *s < 10 => {
78 write!(f, "{:?}", Vec::<i64>::try_from(self).unwrap())
79 }
80 ([], false, true) => write!(f, "[{}]", f64::try_from(self).unwrap()),
81 ([s], false, true) if *s < 10 => {
82 write!(f, "{:?}", Vec::<f64>::try_from(self).unwrap())
83 }
84 _ => write!(f, "Tensor[{:?}, {:?}]", self.size(), kind),
85 }
86 }
87 }
88 } else {
89 write!(f, "Tensor[Undefined]")
90 }
91 }
92}
93
/// Options controlling how tensors are pretty-printed by `Display`.
///
/// Install them globally with `set_print_options`; `PrinterOptions::default()` gives the
/// standard settings, which can be adjusted with struct-update syntax, e.g.
/// `PrinterOptions { precision: 2, ..Default::default() }`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PrinterOptions {
    /// Number of digits printed after the decimal point for floating-point values
    /// (fixed-point and scientific notation).
    pub precision: usize,
    /// Tensors with more elements than this are summarized: only edge items are shown.
    pub threshold: usize,
    /// Number of leading and trailing entries shown per dimension when summarizing.
    pub edge_items: usize,
    /// Target line width, used to decide how many values of a 1-D row fit on one line.
    pub line_width: usize,
    /// `Some(true)` / `Some(false)` forces scientific notation on / off for floats;
    /// `None` chooses automatically from the magnitudes of the values.
    pub sci_mode: Option<bool>,
}
102
103lazy_static! {
104 static ref PRINT_OPTS: std::sync::Mutex<PrinterOptions> =
105 std::sync::Mutex::new(Default::default());
106}
107
108pub fn set_print_options(options: PrinterOptions) {
109 *PRINT_OPTS.lock().unwrap() = options
110}
111
112pub fn set_print_options_default() {
113 *PRINT_OPTS.lock().unwrap() = Default::default()
114}
115
116pub fn set_print_options_short() {
117 *PRINT_OPTS.lock().unwrap() = PrinterOptions {
118 precision: 2,
119 threshold: 1000,
120 edge_items: 2,
121 line_width: 80,
122 sci_mode: None,
123 }
124}
125
126pub fn set_print_options_full() {
127 *PRINT_OPTS.lock().unwrap() = PrinterOptions {
128 precision: 4,
129 threshold: usize::MAX,
130 edge_items: 3,
131 line_width: 80,
132 sci_mode: None,
133 }
134}
135
136impl Default for PrinterOptions {
137 fn default() -> Self {
138 Self { precision: 4, threshold: 1000, edge_items: 3, line_width: 80, sci_mode: None }
139 }
140}
141
142trait TensorFormatter {
143 type Elem;
144
145 fn fmt<T: std::fmt::Write>(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result;
146
147 fn value(tensor: &Tensor) -> Self::Elem;
148
149 fn values(tensor: &Tensor) -> Vec<Self::Elem>;
150
151 fn max_width(&self, to_display: &Tensor) -> usize {
152 let mut max_width = 1;
153 for v in Self::values(to_display) {
154 let mut fmt_size = FmtSize::new();
155 let _res = self.fmt(v, 1, &mut fmt_size);
156 max_width = usize::max(max_width, fmt_size.final_size())
157 }
158 max_width
159 }
160
161 fn write_newline_indent(i: usize, f: &mut std::fmt::Formatter) -> std::fmt::Result {
162 writeln!(f)?;
163 for _ in 0..i {
164 write!(f, " ")?
165 }
166 Ok(())
167 }
168
169 fn fmt_tensor(
170 &self,
171 t: &Tensor,
172 indent: usize,
173 max_w: usize,
174 summarize: bool,
175 po: &PrinterOptions,
176 f: &mut std::fmt::Formatter,
177 ) -> std::fmt::Result {
178 let size = t.size();
179 let edge_items = po.edge_items as i64;
180 write!(f, "[")?;
181 match size.as_slice() {
182 [] => self.fmt(Self::value(t), max_w, f)?,
183 [v] if summarize && *v > 2 * edge_items => {
184 for v in Self::values(&t.slice(0, None, Some(edge_items), 1)).into_iter() {
185 self.fmt(v, max_w, f)?;
186 write!(f, ", ")?;
187 }
188 write!(f, "...")?;
189 for v in Self::values(&t.slice(0, Some(-edge_items), None, 1)).into_iter() {
190 write!(f, ", ")?;
191 self.fmt(v, max_w, f)?
192 }
193 }
194 [_] => {
195 let elements_per_line = usize::max(1, po.line_width / (max_w + 2));
196 for (i, v) in Self::values(t).into_iter().enumerate() {
197 if i > 0 {
198 if i % elements_per_line == 0 {
199 write!(f, ",")?;
200 Self::write_newline_indent(indent, f)?
201 } else {
202 write!(f, ", ")?;
203 }
204 }
205 self.fmt(v, max_w, f)?
206 }
207 }
208 _ => {
209 if summarize && size[0] > 2 * edge_items {
210 for i in 0..edge_items {
211 self.fmt_tensor(&t.get(i), indent + 1, max_w, summarize, po, f)?;
212 write!(f, ",")?;
213 Self::write_newline_indent(indent, f)?
214 }
215 write!(f, "...")?;
216 Self::write_newline_indent(indent, f)?;
217 for i in size[0] - edge_items..size[0] {
218 self.fmt_tensor(&t.get(i), indent + 1, max_w, summarize, po, f)?;
219 if i + 1 != size[0] {
220 write!(f, ",")?;
221 Self::write_newline_indent(indent, f)?
222 }
223 }
224 } else {
225 for i in 0..size[0] {
226 self.fmt_tensor(&t.get(i), indent + 1, max_w, summarize, po, f)?;
227 if i + 1 != size[0] {
228 write!(f, ",")?;
229 Self::write_newline_indent(indent, f)?
230 }
231 }
232 }
233 }
234 }
235 write!(f, "]")?;
236 Ok(())
237 }
238}
239
/// Float element formatter; its display mode is chosen once per tensor by
/// `FloatFormatter::new`.
struct FloatFormatter {
    /// Print values as integers followed by a `.` (non-finite values get no `.`).
    int_mode: bool,
    /// Print values in scientific notation.
    sci_mode: bool,
    /// Fractional digits used in scientific and fixed-point notation.
    precision: usize,
}
245
/// A `std::fmt::Write` sink that discards its input and only counts the characters
/// written; used to measure how wide a formatted value is.
struct FmtSize {
    current_size: usize,
}

impl FmtSize {
    /// Creates a counter starting at zero.
    fn new() -> Self {
        Self { current_size: 0 }
    }

    /// Consumes the counter and returns the number of characters written.
    fn final_size(self) -> usize {
        self.current_size
    }
}

impl std::fmt::Write for FmtSize {
    /// Counts `char`s rather than bytes: `std::fmt` width padding is measured in
    /// characters, so this keeps measured widths consistent with the padding applied later.
    fn write_str(&mut self, s: &str) -> std::fmt::Result {
        self.current_size += s.chars().count();
        Ok(())
    }
}
266
267impl FloatFormatter {
268 fn new(t: &Tensor, po: &PrinterOptions) -> Self {
269 let mut int_mode = true;
270 let mut sci_mode = false;
271
272 let _guard = crate::no_grad_guard();
273 let t = t.to_device(crate::Device::Cpu);
274
275 let nonzero_finite_vals = {
278 let t = t.reshape([-1]);
279 t.masked_select(&t.isfinite().logical_and(&t.ne(0.)))
280 };
281
282 let values = Vec::<f64>::try_from(&nonzero_finite_vals).unwrap();
283 if nonzero_finite_vals.numel() > 0 {
284 let nonzero_finite_abs = nonzero_finite_vals.abs();
285 let nonzero_finite_min = nonzero_finite_abs.min().double_value(&[]);
286 let nonzero_finite_max = nonzero_finite_abs.max().double_value(&[]);
287
288 for &value in values.iter() {
289 if value.ceil() != value {
290 int_mode = false;
291 break;
292 }
293 }
294
295 sci_mode = nonzero_finite_max / nonzero_finite_min > 1000.
296 || nonzero_finite_max > 1e8
297 || nonzero_finite_min < 1e-4
298 }
299
300 match po.sci_mode {
301 None => {}
302 Some(v) => sci_mode = v,
303 }
304 Self { int_mode, sci_mode, precision: po.precision }
305 }
306}
307
308impl TensorFormatter for FloatFormatter {
309 type Elem = f64;
310
311 fn fmt<T: std::fmt::Write>(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result {
312 if self.sci_mode {
313 write!(f, "{v:width$.prec$e}", v = v, width = max_w, prec = self.precision)
314 } else if self.int_mode {
315 if v.is_finite() {
316 write!(f, "{v:width$.0}.", v = v, width = max_w - 1)
317 } else {
318 write!(f, "{v:max_w$.0}")
319 }
320 } else {
321 write!(f, "{v:width$.prec$}", v = v, width = max_w, prec = self.precision)
322 }
323 }
324
325 fn value(tensor: &Tensor) -> Self::Elem {
326 tensor.double_value(&[])
327 }
328
329 fn values(tensor: &Tensor) -> Vec<Self::Elem> {
330 Vec::<Self::Elem>::try_from(tensor.reshape(-1)).unwrap()
331 }
332}
333
334struct IntFormatter;
335
336impl TensorFormatter for IntFormatter {
337 type Elem = i64;
338
339 fn fmt<T: std::fmt::Write>(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result {
340 write!(f, "{v:max_w$}")
341 }
342
343 fn value(tensor: &Tensor) -> Self::Elem {
344 tensor.int64_value(&[])
345 }
346
347 fn values(tensor: &Tensor) -> Vec<Self::Elem> {
348 Vec::<Self::Elem>::try_from(tensor.reshape(-1)).unwrap()
349 }
350}
351
352struct BoolFormatter;
353
354impl TensorFormatter for BoolFormatter {
355 type Elem = bool;
356
357 fn fmt<T: std::fmt::Write>(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result {
358 let v = if v { "true" } else { "false" };
359 write!(f, "{v:max_w$}")
360 }
361
362 fn value(tensor: &Tensor) -> Self::Elem {
363 tensor.int64_value(&[]) != 0
364 }
365
366 fn values(tensor: &Tensor) -> Vec<Self::Elem> {
367 Vec::<Self::Elem>::try_from(tensor.reshape(-1)).unwrap()
368 }
369}
370
371fn get_summarized_data(t: &Tensor, edge_items: i64) -> Tensor {
372 let size = t.size();
373 if size.is_empty() {
374 t.shallow_clone()
375 } else if size.len() == 1 {
376 if size[0] > 2 * edge_items {
377 Tensor::cat(
378 &[t.slice(0, None, Some(edge_items), 1), t.slice(0, Some(-edge_items), None, 1)],
379 0,
380 )
381 } else {
382 t.shallow_clone()
383 }
384 } else if size[0] > 2 * edge_items {
385 let mut vs: Vec<_> =
386 (0..edge_items).map(|i| get_summarized_data(&t.get(i), edge_items)).collect();
387 for i in (size[0] - edge_items)..size[0] {
388 vs.push(get_summarized_data(&t.get(i), edge_items))
389 }
390 Tensor::stack(&vs, 0)
391 } else {
392 let vs: Vec<_> = (0..size[0]).map(|i| get_summarized_data(&t.get(i), edge_items)).collect();
393 Tensor::stack(&vs, 0)
394 }
395}
396
397impl std::fmt::Display for Tensor {
398 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
399 if self.defined() {
400 let po = PRINT_OPTS.lock().unwrap();
401 let summarize = self.numel() > po.threshold;
402 let basic_kind = BasicKind::for_tensor(self);
403 let to_display = if summarize {
404 get_summarized_data(self, po.edge_items as i64)
405 } else {
406 self.shallow_clone()
407 };
408 match basic_kind {
409 BasicKind::Int => {
410 let tf = IntFormatter;
411 let max_w = tf.max_width(&to_display);
412 tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
413 writeln!(f)?;
414 }
415 BasicKind::Float => {
416 let tf = FloatFormatter::new(&to_display, &po);
417 let max_w = tf.max_width(&to_display);
418 tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
419 writeln!(f)?;
420 }
421 BasicKind::Bool => {
422 let tf = BoolFormatter;
423 let max_w = tf.max_width(&to_display);
424 tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
425 writeln!(f)?;
426 }
427 BasicKind::Complex => {}
428 };
429 let kind = match self.f_kind() {
430 Ok(kind) => format!("{kind:?}"),
431 Err(err) => format!("{err:?}"),
432 };
433 write!(f, "Tensor[{:?}, {}]", self.size(), kind)
434 } else {
435 write!(f, "Tensor[Undefined]")
436 }
437 }
438}