web_rwkv/
num.rs

1use bytemuck::Pod;
2use half::f16;
3use safetensors::Dtype;
4
/// Types that have an additive identity.
///
/// `zero()` yields the value `z` for which `x + z == x`; the `Add` supertrait lets generic
/// code combine that value with other values of the same type.
pub trait Zero: Sized + core::ops::Add<Output = Self> {
    /// Returns the additive identity of `Self`.
    fn zero() -> Self;
}
9impl Zero for f32 {
10    fn zero() -> Self {
11        0.0
12    }
13}
14
15impl Zero for f16 {
16    fn zero() -> Self {
17        Self::ZERO
18    }
19}
20
21impl Zero for u8 {
22    fn zero() -> Self {
23        0
24    }
25}
26
27impl Zero for u16 {
28    fn zero() -> Self {
29        0
30    }
31}
32
33impl Zero for u32 {
34    fn zero() -> Self {
35        0
36    }
37}
38
/// Types that have a multiplicative identity.
///
/// `one()` yields the value `e` for which `x * e == x`; the `Mul` supertrait lets generic
/// code combine that value with other values of the same type.
pub trait One: Sized + core::ops::Mul<Output = Self> {
    /// Returns the multiplicative identity of `Self`.
    fn one() -> Self;
}
43impl One for f32 {
44    fn one() -> Self {
45        1.0
46    }
47}
48
49impl One for f16 {
50    fn one() -> Self {
51        Self::ONE
52    }
53}
54
55impl One for u8 {
56    fn one() -> Self {
57        1
58    }
59}
60
61impl One for u16 {
62    fn one() -> Self {
63        1
64    }
65}
66
67impl One for u32 {
68    fn one() -> Self {
69        1
70    }
71}
72
73pub trait Scalar: Sized + Clone + Copy + Pod + Zero + One + Send + Sync + sealed::Sealed {
74    /// Size of the type in bytes.
75    fn size() -> usize {
76        std::mem::size_of::<Self>()
77    }
78
79    const DATA_TYPE: Dtype;
80}
81
82impl Scalar for f32 {
83    const DATA_TYPE: Dtype = Dtype::F32;
84}
85impl Scalar for f16 {
86    const DATA_TYPE: Dtype = Dtype::F16;
87}
88impl Scalar for u8 {
89    const DATA_TYPE: Dtype = Dtype::U8;
90}
91impl Scalar for u16 {
92    const DATA_TYPE: Dtype = Dtype::U16;
93}
94impl Scalar for u32 {
95    const DATA_TYPE: Dtype = Dtype::U32;
96}
97
98pub trait Float: Scalar + Hom<f16> + Hom<f32> + CoHom<f16> + CoHom<f32> {
99    const DEF: &'static str;
100}
101
102impl Float for f32 {
103    const DEF: &'static str = "FP32";
104}
105
106impl Float for f16 {
107    const DEF: &'static str = "FP16";
108}
109
/// Conversion of `self` into the numeric type `T`.
///
/// Conversions may be lossy; for example, narrowing `f32` to `f16` cannot represent every value.
/// See `CoHom` for the same conversion written from the target type's side.
// The type parameter is deliberately not called `Into`: that name would shadow the prelude
// trait `core::convert::Into` everywhere inside this trait.
pub trait Hom<T> {
    /// Converts `self` into a `T`.
    fn hom(self) -> T;
}
114impl Hom<f32> for f32 {
115    fn hom(self) -> f32 {
116        self
117    }
118}
119
120impl Hom<f16> for f32 {
121    fn hom(self) -> f16 {
122        f16::from_f32(self)
123    }
124}
125
126impl Hom<f32> for f16 {
127    fn hom(self) -> f32 {
128        self.to_f32()
129    }
130}
131
132impl Hom<f16> for f16 {
133    fn hom(self) -> f16 {
134        self
135    }
136}
137
138pub trait CoHom<From> {
139    fn co_hom(value: From) -> Self;
140}
141
142impl<From, Into> CoHom<From> for Into
143where
144    From: Hom<Into>,
145{
146    fn co_hom(value: From) -> Self {
147        value.hom()
148    }
149}
150
151mod sealed {
152    use half::f16;
153
154    pub trait Sealed {}
155
156    impl Sealed for f32 {}
157    impl Sealed for f16 {}
158    impl Sealed for u8 {}
159    impl Sealed for u16 {}
160    impl Sealed for u32 {}
161}