1use bytemuck::Pod;
2use half::f16;
3use safetensors::Dtype;
4
/// Types that have an additive identity.
///
/// Implementors return the value that plays the role of `0` under `+`,
/// which makes `Self::zero()` a natural seed for summations.
pub trait Zero
where
    Self: Sized + core::ops::Add<Output = Self>,
{
    /// Returns the additive identity for this type.
    fn zero() -> Self;
}
8
9impl Zero for f32 {
10 fn zero() -> Self {
11 0.0
12 }
13}
14
15impl Zero for f16 {
16 fn zero() -> Self {
17 Self::ZERO
18 }
19}
20
21impl Zero for u8 {
22 fn zero() -> Self {
23 0
24 }
25}
26
27impl Zero for u16 {
28 fn zero() -> Self {
29 0
30 }
31}
32
33impl Zero for u32 {
34 fn zero() -> Self {
35 0
36 }
37}
38
/// Types that have a multiplicative identity.
///
/// Implementors return the value that plays the role of `1` under `*`,
/// which makes `Self::one()` a natural seed for products.
pub trait One
where
    Self: Sized + core::ops::Mul<Output = Self>,
{
    /// Returns the multiplicative identity for this type.
    fn one() -> Self;
}
42
43impl One for f32 {
44 fn one() -> Self {
45 1.0
46 }
47}
48
49impl One for f16 {
50 fn one() -> Self {
51 Self::ONE
52 }
53}
54
55impl One for u8 {
56 fn one() -> Self {
57 1
58 }
59}
60
61impl One for u16 {
62 fn one() -> Self {
63 1
64 }
65}
66
67impl One for u32 {
68 fn one() -> Self {
69 1
70 }
71}
72
73pub trait Scalar: Sized + Clone + Copy + Pod + Zero + One + Send + Sync + sealed::Sealed {
74 fn size() -> usize {
76 std::mem::size_of::<Self>()
77 }
78
79 const DATA_TYPE: Dtype;
80}
81
82impl Scalar for f32 {
83 const DATA_TYPE: Dtype = Dtype::F32;
84}
85impl Scalar for f16 {
86 const DATA_TYPE: Dtype = Dtype::F16;
87}
88impl Scalar for u8 {
89 const DATA_TYPE: Dtype = Dtype::U8;
90}
91impl Scalar for u16 {
92 const DATA_TYPE: Dtype = Dtype::U16;
93}
94impl Scalar for u32 {
95 const DATA_TYPE: Dtype = Dtype::U32;
96}
97
98pub trait Float: Scalar + Hom<f16> + Hom<f32> + CoHom<f16> + CoHom<f32> {
99 const DEF: &'static str;
100}
101
102impl Float for f32 {
103 const DEF: &'static str = "FP32";
104}
105
106impl Float for f16 {
107 const DEF: &'static str = "FP16";
108}
109
/// A by-value conversion from `Self` into `Target`.
///
/// Each source/target pair is implemented explicitly. That includes the
/// precision-losing `f32` -> `f16` direction.
pub trait Hom<Target> {
    /// Consumes `self` and returns the corresponding `Target` value.
    fn hom(self) -> Target;
}
113
114impl Hom<f32> for f32 {
115 fn hom(self) -> f32 {
116 self
117 }
118}
119
120impl Hom<f16> for f32 {
121 fn hom(self) -> f16 {
122 f16::from_f32(self)
123 }
124}
125
126impl Hom<f32> for f16 {
127 fn hom(self) -> f32 {
128 self.to_f32()
129 }
130}
131
132impl Hom<f16> for f16 {
133 fn hom(self) -> f16 {
134 self
135 }
136}
137
/// The target-side view of a conversion: builds `Self` from a `Source`.
///
/// In this module it is provided by a blanket impl for every pair where
/// `Source: Hom<Self>`.
pub trait CoHom<Source> {
    /// Constructs `Self` from `value`.
    fn co_hom(value: Source) -> Self;
}
141
142impl<From, Into> CoHom<From> for Into
143where
144 From: Hom<Into>,
145{
146 fn co_hom(value: From) -> Self {
147 value.hom()
148 }
149}
150
151mod sealed {
152 use half::f16;
153
154 pub trait Sealed {}
155
156 impl Sealed for f32 {}
157 impl Sealed for f16 {}
158 impl Sealed for u8 {}
159 impl Sealed for u16 {}
160 impl Sealed for u32 {}
161}