tch_plus/wrappers/
kind.rs

//! The different kinds of elements supported in Torch.
2
3use half;
4
/// The different kinds of elements that a Tensor can hold.
///
/// Variants are listed in the same order as libtorch's scalar-type tags; the
/// byte width noted on each variant is the per-element storage size.
// `upper_case_acronyms` is silenced so variant names can mirror libtorch's
// spelling (e.g. `QUInt8`, `BFloat16`). NOTE(review): the default clippy
// configuration may not flag any of these names — confirm it is still needed.
#[allow(clippy::upper_case_acronyms)]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Kind {
    /// Unsigned 8-bit integer (`u8`), 1 byte.
    Uint8,
    /// Signed 8-bit integer (`i8`), 1 byte.
    Int8,
    /// Signed 16-bit integer (`i16`), 2 bytes.
    Int16,
    /// Signed 32-bit integer (`i32`), 4 bytes.
    Int,
    /// Signed 64-bit integer (`i64`), 8 bytes.
    Int64,
    /// Half-precision float (`half::f16`), 2 bytes.
    Half,
    /// 32-bit float (`f32`), 4 bytes.
    Float,
    /// 64-bit float (`f64`), 8 bytes.
    Double,
    /// Complex number with half-precision parts, 4 bytes.
    ComplexHalf,
    /// Complex number with 32-bit float parts, 8 bytes.
    ComplexFloat,
    /// Complex number with 64-bit float parts, 16 bytes.
    ComplexDouble,
    /// Boolean (`bool`), 1 byte.
    Bool,
    /// Torch quantized signed 8-bit type, 1 byte.
    QInt8,
    /// Torch quantized unsigned 8-bit type, 1 byte.
    QUInt8,
    /// Torch quantized signed 32-bit type, 4 bytes.
    QInt32,
    /// 16-bit brain floating point (bfloat16), 2 bytes.
    BFloat16,
    /// 8-bit float in the `e5m2` format, 1 byte.
    Float8e5m2,
    /// 8-bit float in the `e4m3fn` format, 1 byte.
    Float8e4m3fn,
    /// 8-bit float in the `e5m2fnuz` format, 1 byte.
    Float8e5m2fnuz,
    /// 8-bit float in the `e4m3fnuz` format, 1 byte.
    Float8e4m3fnuz,
}
30
31impl Kind {
32    pub(super) fn c_int(self) -> libc::c_int {
33        // These values should be in sync with include/c10/core/ScalarType.h
34        // https://github.com/pytorch/pytorch/blob/a8d6afb511a69687bbb2b7e88a3cf67917e1697e/c10/core/ScalarType.h#L57
35        match self {
36            Kind::Uint8 => 0,
37            Kind::Int8 => 1,
38            Kind::Int16 => 2,
39            Kind::Int => 3,
40            Kind::Int64 => 4,
41            Kind::Half => 5,
42            Kind::Float => 6,
43            Kind::Double => 7,
44            Kind::ComplexHalf => 8,
45            Kind::ComplexFloat => 9,
46            Kind::ComplexDouble => 10,
47            Kind::Bool => 11,
48            Kind::QInt8 => 12,
49            Kind::QUInt8 => 13,
50            Kind::QInt32 => 14,
51            Kind::BFloat16 => 15,
52            Kind::Float8e5m2 => 23,
53            Kind::Float8e4m3fn => 24,
54            Kind::Float8e5m2fnuz => 25,
55            Kind::Float8e4m3fnuz => 26,
56        }
57    }
58
59    pub(super) fn from_c_int(v: libc::c_int) -> Result<Kind, crate::TchError> {
60        match v {
61            0 => Ok(Kind::Uint8),
62            1 => Ok(Kind::Int8),
63            2 => Ok(Kind::Int16),
64            3 => Ok(Kind::Int),
65            4 => Ok(Kind::Int64),
66            5 => Ok(Kind::Half),
67            6 => Ok(Kind::Float),
68            7 => Ok(Kind::Double),
69            8 => Ok(Kind::ComplexHalf),
70            9 => Ok(Kind::ComplexFloat),
71            10 => Ok(Kind::ComplexDouble),
72            11 => Ok(Kind::Bool),
73            12 => Ok(Kind::QInt8),
74            13 => Ok(Kind::QUInt8),
75            14 => Ok(Kind::QInt32),
76            15 => Ok(Kind::BFloat16),
77            23 => Ok(Kind::Float8e5m2),
78            24 => Ok(Kind::Float8e4m3fn),
79            25 => Ok(Kind::Float8e5m2fnuz),
80            26 => Ok(Kind::Float8e4m3fnuz),
81            _ => Err(crate::TchError::UnknownKind(v)),
82        }
83    }
84
85    pub fn elt_size_in_bytes(self) -> usize {
86        match self {
87            Kind::Uint8 => 1,
88            Kind::Int8 => 1,
89            Kind::Int16 => 2,
90            Kind::Int => 4,
91            Kind::Int64 => 8,
92            Kind::Half => 2,
93            Kind::Float => 4,
94            Kind::Double => 8,
95            Kind::ComplexHalf => 4,
96            Kind::ComplexFloat => 8,
97            Kind::ComplexDouble => 16,
98            Kind::Bool => 1,
99            Kind::QInt8 => 1,
100            Kind::QUInt8 => 1,
101            Kind::QInt32 => 4,
102            Kind::BFloat16 => 2,
103            Kind::Float8e5m2 => 1,
104            Kind::Float8e4m3fn => 1,
105            Kind::Float8e5m2fnuz => 1,
106            Kind::Float8e4m3fnuz => 1,
107        }
108    }
109}
110
111pub const FLOAT_CPU: (Kind, crate::Device) = (Kind::Float, crate::Device::Cpu);
112pub const DOUBLE_CPU: (Kind, crate::Device) = (Kind::Double, crate::Device::Cpu);
113pub const INT64_CPU: (Kind, crate::Device) = (Kind::Int64, crate::Device::Cpu);
114
115pub const FLOAT_CUDA: (Kind, crate::Device) = (Kind::Float, crate::Device::Cuda(0));
116pub const DOUBLE_CUDA: (Kind, crate::Device) = (Kind::Double, crate::Device::Cuda(0));
117pub const INT64_CUDA: (Kind, crate::Device) = (Kind::Int64, crate::Device::Cuda(0));
118
119/// Kinds for tensor elements
120///
121/// # Safety
122/// The specified Kind must be for a type that has the same length as Self.
123pub unsafe trait Element: Clone {
124    const KIND: Kind;
125    const ZERO: Self;
126}
127
128unsafe impl Element for u8 {
129    const KIND: Kind = Kind::Uint8;
130    const ZERO: Self = 0;
131}
132
133unsafe impl Element for i8 {
134    const KIND: Kind = Kind::Int8;
135    const ZERO: Self = 0;
136}
137
138unsafe impl Element for i16 {
139    const KIND: Kind = Kind::Int16;
140    const ZERO: Self = 0;
141}
142
143unsafe impl Element for i32 {
144    const KIND: Kind = Kind::Int;
145    const ZERO: Self = 0;
146}
147
148unsafe impl Element for i64 {
149    const KIND: Kind = Kind::Int64;
150    const ZERO: Self = 0;
151}
152
153unsafe impl Element for half::f16 {
154    const KIND: Kind = Kind::Half;
155    const ZERO: Self = half::f16::ZERO;
156}
157
158unsafe impl Element for half::bf16 {
159    const KIND: Kind = Kind::Half;
160    const ZERO: Self = half::bf16::ZERO;
161}
162
163unsafe impl Element for f32 {
164    const KIND: Kind = Kind::Float;
165    const ZERO: Self = 0.;
166}
167
168unsafe impl Element for f64 {
169    const KIND: Kind = Kind::Double;
170    const ZERO: Self = 0.;
171}
172
173unsafe impl Element for bool {
174    const KIND: Kind = Kind::Bool;
175    const ZERO: Self = false;
176}