Skip to main content

vyre_spec/data_type/
display.rs

1//! Display implementations for frozen data-type contracts.
2
3use core::fmt;
4
5use super::{DataType, QuantizationScale, QuantizationZeroPoint};
6
7impl fmt::Display for DataType {
8    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
9        match self {
10            Self::U8 => f.write_str("u8"),
11            Self::U16 => f.write_str("u16"),
12            Self::U32 => f.write_str("u32"),
13            Self::I8 => f.write_str("i8"),
14            Self::I16 => f.write_str("i16"),
15            Self::I32 => f.write_str("i32"),
16            Self::I64 => f.write_str("i64"),
17            Self::U64 => f.write_str("u64"),
18            Self::Vec2U32 => f.write_str("vec2<u32>"),
19            Self::Vec4U32 => f.write_str("vec4<u32>"),
20            Self::Bool => f.write_str("bool"),
21            Self::Bytes => f.write_str("bytes"),
22            Self::Array { element_size } => write!(f, "array<{element_size}B>"),
23            Self::F16 => f.write_str("f16"),
24            Self::BF16 => f.write_str("bf16"),
25            Self::F32 => f.write_str("f32"),
26            Self::F64 => f.write_str("f64"),
27            Self::Tensor => f.write_str("tensor"),
28            Self::Handle(id) => write!(f, "handle<{:#010x}>", id.as_u32()),
29            Self::Vec { element, count } => write!(f, "vec<{element};{count}>"),
30            Self::TensorShaped { element, shape } => {
31                write!(f, "tensor<{element};")?;
32                for (idx, dim) in shape.iter().enumerate() {
33                    if idx > 0 {
34                        f.write_str("x")?;
35                    }
36                    write!(f, "{dim}")?;
37                }
38                f.write_str(">")
39            }
40            Self::Opaque(id) => write!(f, "opaque<{:#010x}>", id.as_u32()),
41            Self::F8E4M3 => f.write_str("f8e4m3"),
42            Self::F8E5M2 => f.write_str("f8e5m2"),
43            Self::I4 => f.write_str("i4"),
44            Self::FP4 => f.write_str("fp4"),
45            Self::NF4 => f.write_str("nf4"),
46            Self::SparseCsr { element } => write!(f, "sparse_csr<{element}>"),
47            Self::SparseCoo { element } => write!(f, "sparse_coo<{element}>"),
48            Self::SparseBsr {
49                element,
50                block_rows,
51                block_cols,
52            } => write!(f, "sparse_bsr<{element};{block_rows}x{block_cols}>"),
53            Self::DeviceMesh { axes } => {
54                f.write_str("device_mesh<")?;
55                for (idx, axis) in axes.iter().enumerate() {
56                    if idx > 0 {
57                        f.write_str("x")?;
58                    }
59                    write!(f, "{axis}")?;
60                }
61                f.write_str(">")
62            }
63            Self::Quantized {
64                storage,
65                scale,
66                zero_point,
67            } => write!(f, "quantized<{storage};{scale};{zero_point}>"),
68        }
69    }
70}
71
72impl fmt::Display for QuantizationScale {
73    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
74        match self {
75            Self::PerTensor => f.write_str("scale:per_tensor"),
76            Self::PerChannel { axis } => write!(f, "scale:per_channel(axis={axis})"),
77            Self::PerGroup { group_size } => write!(f, "scale:per_group(size={group_size})"),
78        }
79    }
80}
81
82impl fmt::Display for QuantizationZeroPoint {
83    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
84        match self {
85            Self::Absent => f.write_str("zp:absent"),
86            Self::PerTensor => f.write_str("zp:per_tensor"),
87            Self::PerChannel { axis } => write!(f, "zp:per_channel(axis={axis})"),
88            Self::PerGroup { group_size } => write!(f, "zp:per_group(size={group_size})"),
89        }
90    }
91}