1use half;
4
/// Tag identifying the element type stored in a tensor.
///
/// Every variant has an integer code (see `Kind::c_int`) and a fixed
/// per-element byte size (see `Kind::elt_size_in_bytes`).
// Names such as `QUInt8` are kept as-is rather than re-cased to satisfy clippy.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum Kind {
    /// One-byte unsigned integer; the kind of `u8`.
    Uint8,
    /// One-byte signed integer; the kind of `i8`.
    Int8,
    /// Two-byte signed integer; the kind of `i16`.
    Int16,
    /// Four-byte signed integer; the kind of `i32`.
    Int,
    /// Eight-byte signed integer; the kind of `i64`.
    Int64,
    /// Two-byte floating point; the kind of `half::f16`.
    Half,
    /// Four-byte floating point; the kind of `f32`.
    Float,
    /// Eight-byte floating point; the kind of `f64`.
    Double,
    /// Complex number occupying four bytes per element.
    ComplexHalf,
    /// Complex number occupying eight bytes per element.
    ComplexFloat,
    /// Complex number occupying sixteen bytes per element.
    ComplexDouble,
    /// One-byte boolean; the kind of `bool`.
    Bool,
    /// One-byte element (presumably a quantized signed integer — confirm upstream).
    QInt8,
    /// One-byte element (presumably a quantized unsigned integer — confirm upstream).
    QUInt8,
    /// Four-byte element (presumably a quantized signed integer — confirm upstream).
    QInt32,
    /// Two-byte "brain" floating-point format.
    BFloat16,
    /// One-byte floating-point format, `e5m2` flavour.
    Float8e5m2,
    /// One-byte floating-point format, `e4m3fn` flavour.
    Float8e4m3fn,
    /// One-byte floating-point format, `e5m2fnuz` flavour.
    Float8e5m2fnuz,
    /// One-byte floating-point format, `e4m3fnuz` flavour.
    Float8e4m3fnuz,
}
30
31impl Kind {
32 pub(super) fn c_int(self) -> libc::c_int {
33 match self {
36 Kind::Uint8 => 0,
37 Kind::Int8 => 1,
38 Kind::Int16 => 2,
39 Kind::Int => 3,
40 Kind::Int64 => 4,
41 Kind::Half => 5,
42 Kind::Float => 6,
43 Kind::Double => 7,
44 Kind::ComplexHalf => 8,
45 Kind::ComplexFloat => 9,
46 Kind::ComplexDouble => 10,
47 Kind::Bool => 11,
48 Kind::QInt8 => 12,
49 Kind::QUInt8 => 13,
50 Kind::QInt32 => 14,
51 Kind::BFloat16 => 15,
52 Kind::Float8e5m2 => 23,
53 Kind::Float8e4m3fn => 24,
54 Kind::Float8e5m2fnuz => 25,
55 Kind::Float8e4m3fnuz => 26,
56 }
57 }
58
59 pub(super) fn from_c_int(v: libc::c_int) -> Result<Kind, crate::TchError> {
60 match v {
61 0 => Ok(Kind::Uint8),
62 1 => Ok(Kind::Int8),
63 2 => Ok(Kind::Int16),
64 3 => Ok(Kind::Int),
65 4 => Ok(Kind::Int64),
66 5 => Ok(Kind::Half),
67 6 => Ok(Kind::Float),
68 7 => Ok(Kind::Double),
69 8 => Ok(Kind::ComplexHalf),
70 9 => Ok(Kind::ComplexFloat),
71 10 => Ok(Kind::ComplexDouble),
72 11 => Ok(Kind::Bool),
73 12 => Ok(Kind::QInt8),
74 13 => Ok(Kind::QUInt8),
75 14 => Ok(Kind::QInt32),
76 15 => Ok(Kind::BFloat16),
77 23 => Ok(Kind::Float8e5m2),
78 24 => Ok(Kind::Float8e4m3fn),
79 25 => Ok(Kind::Float8e5m2fnuz),
80 26 => Ok(Kind::Float8e4m3fnuz),
81 _ => Err(crate::TchError::UnknownKind(v)),
82 }
83 }
84
85 pub fn elt_size_in_bytes(self) -> usize {
86 match self {
87 Kind::Uint8 => 1,
88 Kind::Int8 => 1,
89 Kind::Int16 => 2,
90 Kind::Int => 4,
91 Kind::Int64 => 8,
92 Kind::Half => 2,
93 Kind::Float => 4,
94 Kind::Double => 8,
95 Kind::ComplexHalf => 4,
96 Kind::ComplexFloat => 8,
97 Kind::ComplexDouble => 16,
98 Kind::Bool => 1,
99 Kind::QInt8 => 1,
100 Kind::QUInt8 => 1,
101 Kind::QInt32 => 4,
102 Kind::BFloat16 => 2,
103 Kind::Float8e5m2 => 1,
104 Kind::Float8e4m3fn => 1,
105 Kind::Float8e5m2fnuz => 1,
106 Kind::Float8e4m3fnuz => 1,
107 }
108 }
109}
110
111pub const FLOAT_CPU: (Kind, crate::Device) = (Kind::Float, crate::Device::Cpu);
112pub const DOUBLE_CPU: (Kind, crate::Device) = (Kind::Double, crate::Device::Cpu);
113pub const INT64_CPU: (Kind, crate::Device) = (Kind::Int64, crate::Device::Cpu);
114
115pub const FLOAT_CUDA: (Kind, crate::Device) = (Kind::Float, crate::Device::Cuda(0));
116pub const DOUBLE_CUDA: (Kind, crate::Device) = (Kind::Double, crate::Device::Cuda(0));
117pub const INT64_CUDA: (Kind, crate::Device) = (Kind::Int64, crate::Device::Cuda(0));
118
119pub unsafe trait Element: Clone {
124 const KIND: Kind;
125 const ZERO: Self;
126}
127
128unsafe impl Element for u8 {
129 const KIND: Kind = Kind::Uint8;
130 const ZERO: Self = 0;
131}
132
133unsafe impl Element for i8 {
134 const KIND: Kind = Kind::Int8;
135 const ZERO: Self = 0;
136}
137
138unsafe impl Element for i16 {
139 const KIND: Kind = Kind::Int16;
140 const ZERO: Self = 0;
141}
142
143unsafe impl Element for i32 {
144 const KIND: Kind = Kind::Int;
145 const ZERO: Self = 0;
146}
147
148unsafe impl Element for i64 {
149 const KIND: Kind = Kind::Int64;
150 const ZERO: Self = 0;
151}
152
153unsafe impl Element for half::f16 {
154 const KIND: Kind = Kind::Half;
155 const ZERO: Self = half::f16::ZERO;
156}
157
158unsafe impl Element for half::bf16 {
159 const KIND: Kind = Kind::Half;
160 const ZERO: Self = half::bf16::ZERO;
161}
162
163unsafe impl Element for f32 {
164 const KIND: Kind = Kind::Float;
165 const ZERO: Self = 0.;
166}
167
168unsafe impl Element for f64 {
169 const KIND: Kind = Kind::Double;
170 const ZERO: Self = 0.;
171}
172
173unsafe impl Element for bool {
174 const KIND: Kind = Kind::Bool;
175 const ZERO: Self = false;
176}