tract_core/ops/nn/softmax/
fixedpoint.rs1pub use num_traits::{AsPrimitive, PrimInt};
2use std::fmt::{Binary, Debug, LowerHex};
3
4use super::math::*;
5
/// Generates a `&self -> Self` method called `$name`.
///
/// The generated method passes the raw value (`self.as_raw()`) to the free function
/// of the same name that is in scope at the invocation site, then wraps the result
/// again with `Self::from_raw`.
macro_rules! impl_fixed_point_func_unary {
    ($name:ident) => {
        #[allow(dead_code)]
        pub fn $name(&self) -> Self {
            let raw = $name(self.as_raw());
            Self::from_raw(raw)
        }
    };
}
14
/// Generates a `(&self, b: Self) -> Self` method called `$name`.
///
/// The generated method passes both raw values (receiver first, then `b`) to the free
/// function of the same name that is in scope at the invocation site, then wraps the
/// result with `Self::from_raw`.
macro_rules! impl_fixed_point_func_binary {
    ($name:ident) => {
        pub fn $name(&self, b: Self) -> Self {
            let lhs = self.as_raw();
            let rhs = b.as_raw();
            Self::from_raw($name(lhs, rhs))
        }
    };
}
22
// Shorthand names for the `i32`-backed formats used in this module. The first number
// in each name is the `INTEGER_BITS` parameter; the second is the number of bits left
// for the fraction once the sign bit is accounted for.

/// `i32`-backed fixed point with 0 integer bits.
pub type Q0_31 = FixedPoint<i32, 0>;
/// `i32`-backed fixed point with 1 integer bit.
pub type Q1_30 = FixedPoint<i32, 1>;
/// `i32`-backed fixed point with 2 integer bits.
pub type Q2_29 = FixedPoint<i32, 2>;
/// `i32`-backed fixed point with 5 integer bits.
pub type Q5_26 = FixedPoint<i32, 5>;
27
28#[derive(PartialEq, Eq,PartialOrd, Copy, Clone)]
29pub struct FixedPoint<T: PrimInt, const INTEGER_BITS: usize>(T);
30
31impl<T, const INTEGER_BITS: usize> FixedPoint<T, INTEGER_BITS>
32where
33 T: PrimInt,
34{
35 pub fn from_raw(x: T) -> Self {
36 Self(x)
37 }
38
39 pub fn one() -> Self {
40 if INTEGER_BITS == 0 {
41 Self(T::max_value())
42 } else {
43 Self(T::one() << Self::fractional_bits())
44 }
45 }
46
47 pub fn fractional_bits() -> usize {
48 if Self::is_signed() {
49 std::mem::size_of::<T>() * 8 - 1 - INTEGER_BITS
50 } else {
51 std::mem::size_of::<T>() * 8 - INTEGER_BITS
52 }
53 }
54
55 #[allow(dead_code)]
56 pub fn zero() -> Self {
57 Self(T::zero())
58 }
59
60 pub fn as_raw(&self) -> T {
61 self.0
62 }
63
64 pub fn is_signed() -> bool {
65 is_signed::<T>()
66 }
67}
68
69impl<T: 'static, const INTEGER_BITS: usize> FixedPoint<T, INTEGER_BITS>
70where
71 T: PrimInt + Debug,
72 usize: AsPrimitive<T>,
73{
74 pub fn constant_pot(exponent: isize) -> Self {
75 let offset = (Self::fractional_bits() as isize + exponent) as usize;
76 assert!(offset < 31);
77 Self(1_usize.as_() << offset)
78 }
79}
80
81impl FixedPoint<i32, 0> {
82 impl_fixed_point_func_unary!(exp_on_interval_between_negative_one_quarter_and_0_excl);
83 impl_fixed_point_func_unary!(one_over_one_plus_x_for_x_in_0_1);
84}
85
86impl FixedPoint<i32, 5> {
87 #[allow(dead_code)]
88 pub fn exp_on_negative_values(&self) -> FixedPoint<i32, 0> {
89 FixedPoint::<i32, 0>::from_raw(exp_on_negative_values(self.as_raw()))
90 }
91}
92
93impl<const INTEGER_BITS: usize> FixedPoint<i32, INTEGER_BITS> {
94 impl_fixed_point_func_unary!(mask_if_non_zero);
95 impl_fixed_point_func_unary!(mask_if_zero);
96 impl_fixed_point_func_binary!(rounding_half_sum);
97
98 pub fn saturating_rounding_multiply_by_pot(&self, exponent: i32) -> Self {
99 Self::from_raw(saturating_rounding_multiply_by_pot(self.as_raw(), exponent))
100 }
101
102 #[allow(dead_code)]
103 pub fn rounding_divide_by_pot(&self, exponent: i32) -> Self {
104 Self::from_raw(rounding_divide_by_pot(self.as_raw(), exponent))
105 }
106
107 pub fn select_using_mask(mask: i32, a: Self, b: Self) -> Self {
108 Self::from_raw(select_using_mask(mask, a.as_raw(), b.as_raw()))
109 }
110
111 pub fn rescale<const DST_INTEGER_BITS: usize>(&self) -> FixedPoint<i32, DST_INTEGER_BITS> {
112 FixedPoint::<i32, DST_INTEGER_BITS>::from_raw(rescale(
113 self.as_raw(),
114 INTEGER_BITS,
115 DST_INTEGER_BITS,
116 ))
117 }
118
119 #[allow(dead_code)]
120 pub fn get_reciprocal(&self) -> (FixedPoint<i32, 0>, usize) {
121 let (raw_res, num_bits_over_units) = get_reciprocal(self.as_raw(), INTEGER_BITS);
122 (FixedPoint::<i32, 0>::from_raw(raw_res), num_bits_over_units)
123 }
124}
125
126impl<T, const INTEGER_BITS: usize> Debug for FixedPoint<T, INTEGER_BITS>
127where
128 T: AsPrimitive<f32> + PrimInt + LowerHex + Debug + Binary,
129 f32: AsPrimitive<T>,
130{
131 fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
132 write!(fmt, "{:032b}({:?})({})", self.0, self.0, self.as_f32())
133 }
134}
135
136impl<T, const INTEGER_BITS: usize> FixedPoint<T, INTEGER_BITS>
137where
138 T: AsPrimitive<f32> + PrimInt,
139{
140 pub fn as_f32(&self) -> f32 {
141 self.0.as_() / 2_f32.powi(Self::fractional_bits() as i32)
142 }
143}
144
145impl<T, const INTEGER_BITS: usize> FixedPoint<T, INTEGER_BITS>
146where
147 T: AsPrimitive<f32> + PrimInt,
148 f32: AsPrimitive<T>,
149{
150 #[allow(dead_code)]
151 pub fn from_f32(x: f32) -> Self {
152 Self::from_raw(
153 f32::min(
154 f32::max(
155 f32::round(x * 2f32.powi(Self::fractional_bits().as_())),
156 T::min_value().as_(),
157 ),
158 T::max_value().as_(),
159 )
160 .as_(),
161 )
162 }
163}
164
165impl<T: PrimInt, const INTEGER_BITS: usize> std::ops::Add for FixedPoint<T, INTEGER_BITS> {
166 type Output = FixedPoint<T, INTEGER_BITS>;
167 fn add(self, rhs: Self) -> Self::Output {
168 Self::from_raw(self.0 + rhs.0)
169 }
170}
171
172impl<T: PrimInt, const INTEGER_BITS: usize> std::ops::Sub for FixedPoint<T, INTEGER_BITS> {
173 type Output = FixedPoint<T, INTEGER_BITS>;
174 fn sub(self, rhs: Self) -> Self::Output {
175 Self::from_raw(self.0 - rhs.0)
176 }
177}
178
179impl<T: PrimInt, const INTEGER_BITS: usize> std::ops::Shl<usize> for FixedPoint<T, INTEGER_BITS> {
180 type Output = FixedPoint<T, INTEGER_BITS>;
181 fn shl(self, rhs: usize) -> Self::Output {
182 Self::from_raw(self.0 << rhs)
183 }
184}
185
186impl<T: PrimInt, const INTEGER_BITS: usize> std::ops::Shr<usize> for FixedPoint<T, INTEGER_BITS> {
187 type Output = FixedPoint<T, INTEGER_BITS>;
188 fn shr(self, rhs: usize) -> Self::Output {
189 Self::from_raw(self.0 >> rhs)
190 }
191}
192
193impl<T: PrimInt, const INTEGER_BITS: usize> std::ops::BitAnd for FixedPoint<T, INTEGER_BITS> {
194 type Output = FixedPoint<T, INTEGER_BITS>;
195 fn bitand(self, rhs: Self) -> Self::Output {
196 Self::from_raw(self.0 & rhs.0)
197 }
198}
199
/// Implements `Mul` between two `FixedPoint` formats that share the raw type `$raw`.
///
/// Arguments: raw integer type, integer bits of the left operand, integer bits of the
/// right operand, integer bits of the product. The product's raw value is whatever
/// `saturating_rounding_doubling_high_mul` returns for the two raw operands.
macro_rules! impl_mul {
    ($raw:ty, $lhs_bits:literal, $rhs_bits:literal, $out_bits:literal) => {
        impl std::ops::Mul<FixedPoint<$raw, $rhs_bits>> for FixedPoint<$raw, $lhs_bits> {
            type Output = FixedPoint<$raw, $out_bits>;

            fn mul(self, rhs: FixedPoint<$raw, $rhs_bits>) -> Self::Output {
                let product = saturating_rounding_doubling_high_mul(self.0, rhs.0);
                FixedPoint::<$raw, $out_bits>::from_raw(product)
            }
        }
    };
}
212
// Supported products, all on `i32` raw values. In every instantiation below the
// output's integer-bit count equals the sum of the operands' integer-bit counts.
impl_mul!(i32, 0, 0, 0);
impl_mul!(i32, 0, 2, 2);
impl_mul!(i32, 2, 0, 2);
impl_mul!(i32, 2, 2, 4);
impl_mul!(i32, 5, 5, 10);
218
#[cfg(test)]
mod test {
    use super::*;
    use approx::assert_abs_diff_eq;

    // Extra formats that only the tests need.
    pub type Q10_21 = FixedPoint<i32, 10>;
    pub type Q12_19 = FixedPoint<i32, 12>;
    pub type Q26_5 = FixedPoint<i32, 26>;
    type Q0_7 = FixedPoint<i8, 0>;

    /// Raw 32 with 5 fractional bits is exactly 1.0.
    #[test]
    fn test_to_f32() {
        let value = Q26_5::from_raw(32);
        assert_eq!(value.as_f32(), 1.0);
    }

    /// Raw 32 in an `i8` with 7 fractional bits is 0.25.
    #[test]
    fn test_to_f32_1() {
        let value = Q0_7::from_raw(32);
        assert_eq!(value.as_f32(), 0.25);
    }

    /// `one()` is `1 << fractional_bits` when integer bits are available.
    #[test]
    fn test_one() {
        let unit = Q26_5::one();
        assert_eq!(unit, Q26_5::from_raw(32));
    }

    /// `one()` is the largest raw value when there are no integer bits.
    #[test]
    fn test_one_limit() {
        let unit = Q0_31::one();
        assert_eq!(unit, Q0_31::from_raw(i32::MAX));
    }

    /// Multiplying two Q5.26 values yields a Q10.21 result.
    #[test]
    fn test_mul_1() {
        let lhs = Q5_26::from_f32(8.0);
        let rhs = Q5_26::from_f32(3.0);
        let want = Q10_21::from_f32(24.0);

        assert_eq!(lhs * rhs, want);
    }

    /// Addition stays in the operands' format.
    #[test]
    fn test_add() {
        let lhs = Q5_26::from_f32(16.0);
        let rhs = Q5_26::from_f32(5.0);
        let want = Q5_26::from_f32(21.0);
        assert_eq!(lhs + rhs, want);
    }

    #[test]
    fn test_one_over_one_plus_x_for_x_in_0_1() {
        let input = Q0_31::from_f32(0.75);
        let want = Q0_31::from_f32(1.0 / 1.75);
        let got = input.one_over_one_plus_x_for_x_in_0_1();
        assert_eq!(got.as_f32(), want.as_f32());
    }

    #[test]
    fn test_one_over_one_plus_x_for_x_in_0_1_1() {
        let input = Q0_31::from_f32(0.0);
        let want = Q0_31::from_f32(1.0 / 1.0);
        let got = input.one_over_one_plus_x_for_x_in_0_1();
        assert_eq!(got.as_f32(), want.as_f32());
    }

    #[test]
    fn test_get_reciprocal_1() {
        let input = Q5_26::from_f32(4.5);
        let want = Q0_31::from_f32(1.0 / 4.5);
        let (shifted, shift) = input.get_reciprocal();
        let got = shifted.rounding_divide_by_pot(shift as i32);
        assert_eq!(got.as_f32(), want.as_f32());
        assert_eq!(shift, 2);
    }

    // NOTE(review): this test uses exactly the same input and expectations as
    // `test_get_reciprocal_1`; presumably a different input was intended. Confirm.
    #[test]
    fn test_get_reciprocal_2() {
        let input = Q5_26::from_f32(4.5);
        let want = Q0_31::from_f32(1.0 / 4.5);
        let (shifted, shift) = input.get_reciprocal();
        let got = shifted.rounding_divide_by_pot(shift as i32);
        assert_eq!(got.as_f32(), want.as_f32());
        assert_eq!(shift, 2);
    }

    #[test]
    fn test_get_reciprocal_3() {
        let input = Q12_19::from_f32(2.0);
        let want = Q0_31::from_f32(1.0 / 2.0);
        let (shifted, shift) = input.get_reciprocal();
        let got = shifted.rounding_divide_by_pot(shift as i32);
        assert_eq!(got.as_f32(), want.as_f32());
        assert_eq!(shift, 1);
    }

    /// Rescaling Q0.31 to Q12.19 keeps the represented value.
    #[test]
    fn test_rescale_1() {
        let input = Q0_31::from_f32(0.75);
        let want = Q12_19::from_f32(0.75);
        let got = input.rescale::<12>();
        assert_eq!(got, want);
    }

    #[test]
    fn test_exp_on_interval_between_negative_one_quarter_and_0_excl() {
        let input = Q0_31::from_f32(-0.125);
        let want = Q0_31::from_f32((-0.125_f32).exp());
        let got = input.exp_on_interval_between_negative_one_quarter_and_0_excl();
        assert_eq!(got.as_f32(), want.as_f32());
    }

    #[test]
    fn test_exp_on_negative_values_1() {
        let input = Q5_26::from_f32(-0.125);
        let want = Q0_31::from_f32((-0.125_f32).exp());
        let got = input.exp_on_negative_values();
        assert_abs_diff_eq!(got.as_f32(), want.as_f32(), epsilon = 0.00001);
    }

    #[test]
    fn test_exp_on_negative_values_2() {
        let input = Q5_26::from_f32(0.0);
        let want = Q0_31::from_f32((0_f32).exp());
        let got = input.exp_on_negative_values();
        assert_abs_diff_eq!(got.as_f32(), want.as_f32(), epsilon = 0.00001);
    }

    #[test]
    fn test_exp_on_negative_values_3() {
        let input = Q5_26::from_f32(-0.25);
        let want = Q0_31::from_f32((-0.25_f32).exp());
        let got = input.exp_on_negative_values();
        assert_abs_diff_eq!(got.as_f32(), want.as_f32(), epsilon = 0.00001);
    }

    #[test]
    fn test_exp_on_negative_values_4() {
        let input = Q5_26::from_f32(-1.1875_f32);
        let want = Q0_31::from_f32((-1.1875_f32).exp());
        let got = input.exp_on_negative_values();
        assert_abs_diff_eq!(got.as_f32(), want.as_f32(), epsilon = 0.00001);
    }
}