tea_dtype/
number.rs

1use std::cmp::PartialOrd;
2use std::ops::{Add, AddAssign, DivAssign, MulAssign, Sub, SubAssign};
3
4use num_traits::{MulAdd, Num};
5
6use super::cast::Cast;
7use super::isnone::IsNone;
8
/// One step of Kahan (compensated) summation.
///
/// Adds `v` to the running total `sum` and returns the new total. `c` is the
/// running compensation term: it is read to correct `v` before the addition
/// and overwritten with the error term computed for this step.
///
/// See <https://en.wikipedia.org/wiki/Kahan_summation_algorithm>.
#[inline]
fn kh_sum<T>(sum: T, v: T, c: &mut T) -> T
where
    T: Add<Output = T> + Sub<Output = T> + Copy,
{
    // Apply the compensation carried over from the previous step.
    let corrected = v - *c;
    let total = sum + corrected;
    // (total - sum) is what was actually added; the difference to `corrected`
    // is stored so the next step can compensate for it.
    *c = (total - sum) - corrected;
    total
}
20/// A trait representing numeric types with various operations and conversions.
21///
22/// This trait combines several other traits and provides additional functionality
23/// for numeric types. It includes operations for arithmetic, comparison, conversion,
24/// and special numeric functions.
25///
26/// # Type Constraints
27///
28/// The type implementing this trait must satisfy the following constraints:
29/// - `Copy`: The type can be copied bit-for-bit.
30/// - `Send`: The type can be safely transferred across thread boundaries.
31/// - `Sync`: The type can be safely shared between threads.
32/// - `IsNone`: The type has a concept of a "none" value.
33/// - `Sized`: The type has a known size at compile-time.
34/// - `Default`: The type has a default value.
35/// - `Num`: The type supports basic numeric operations.
36/// - `AddAssign`, `SubAssign`, `MulAssign`, `DivAssign`: The type supports compound assignment operations.
37/// - `PartialOrd`: The type can be partially ordered.
38/// - `MulAdd`: The type supports fused multiply-add operations.
39/// - `Cast<f64>`, `Cast<f32>`, `Cast<usize>`, `Cast<i32>`, `Cast<i64>`: The type can be cast to these numeric types.
40/// - `'static`: The type has a static lifetime.
41pub trait Number:
42    Copy
43    // + Clone
44    + Send
45    + Sync
46    + IsNone
47    + Sized
48    + Default
49    + Num
50    + AddAssign
51    + SubAssign
52    + MulAssign
53    + DivAssign
54    + PartialOrd
55    + MulAdd
56    + Cast<f64>
57    + Cast<f32>
58    + Cast<usize>
59    + Cast<i32>
60    + Cast<i64>
61    + 'static
62{
63    // type Dtype;
64    /// Returns the minimum value of the data type.
65    fn min_() -> Self;
66
67    /// Returns the maximum value of the data type.
68    fn max_() -> Self;
69
70    /// Computes the absolute value of the number.
71    fn abs(self) -> Self;
72
73    /// Computes the ceiling of the number.
74    ///
75    /// For integer types, this is typically the identity function.
76    #[inline(always)]
77    fn ceil(self) -> Self {
78        self
79    }
80
81    /// Computes the floor of the number.
82    ///
83    /// For integer types, this is typically the identity function.
84    #[inline(always)]
85    fn floor(self) -> Self {
86        self
87    }
88
89    /// Returns the minimum of self and other.
90    #[inline]
91    fn min_with(self, other: Self) -> Self {
92        if other < self {
93            other
94        } else {
95            self
96        }
97    }
98
99    /// Returns the maximum of self and other.
100    #[inline]
101    fn max_with(self, other: Self) -> Self {
102        if other > self {
103            other
104        } else {
105            self
106        }
107    }
108
109    /// Casts the number to f32.
110    #[inline(always)]
111    fn f32(self) -> f32 {
112        Cast::<f32>::cast(self)
113    }
114
115    /// Casts the number to f64.
116    #[inline(always)]
117    fn f64(self) -> f64 {
118        Cast::<f64>::cast(self)
119    }
120
121    /// Casts the number to i32.
122    #[inline(always)]
123    fn i32(self) -> i32 {
124        Cast::<i32>::cast(self)
125    }
126
127    /// Casts the number to i64.
128    #[inline(always)]
129    fn i64(self) -> i64 {
130        Cast::<i64>::cast(self)
131    }
132
133    /// Casts the number to usize.
134    #[inline(always)]
135    fn usize(self) -> usize {
136        Cast::<usize>::cast(self)
137    }
138
139    /// Creates a value of type Self using a value of type U using `Cast`.
140    #[inline(always)]
141    fn fromas<U>(v: U) -> Self
142    where
143        U: Number + Cast<Self>,
144        Self: 'static,
145    {
146        v.to::<Self>()
147    }
148
149    /// Casts self to another type T using `Cast`.
150    #[inline(always)]
151    fn to<T: Number>(self) -> T
152    where
153        Self: Cast<T>,
154    {
155        Cast::<T>::cast(self)
156    }
157
158    /// Performs Kahan summation.
159    ///
160    /// This method implements the Kahan summation algorithm, which helps reduce
161    /// numerical error in the sum of a sequence of floating point numbers.
162    #[inline(always)]
163    #[must_use]
164    fn kh_sum(self, v: Self, c: &mut Self) -> Self {
165        kh_sum(self, v, c)
166    }
167
168    /// Conditionally adds `other` to `self` and increments `n`.
169    ///
170    /// If `other` is not none, it adds `other` to `self` and increments `n`.
171    /// Otherwise, it returns `self` unchanged.
172    #[inline]
173    fn n_add(self, other: Self, n: &mut usize) -> Self {
174        // note: only check if other is NaN
175        // assume that self is not NaN
176        if other.not_none() {
177            *n += 1;
178            self + other
179        } else {
180            self
181        }
182    }
183
184    /// Conditionally multiplies `self` by `other` and increments `n`.
185    ///
186    /// If `other` is not none, it multiplies `self` by `other` and increments `n`.
187    /// Otherwise, it returns `self` unchanged.
188    #[inline]
189    fn n_prod(self, other: Self, n: &mut usize) -> Self {
190        // note: only check if other is NaN
191        // assume that self is not NaN
192        if other.not_none() {
193            *n += 1;
194            self * other
195        } else {
196            self
197        }
198    }
199}
200
// Generates `Number` implementations for primitive numeric types.
//
// Invocation form: `impl_number!(<family> ty, Tag; ty, Tag; ...)`, where
// `<family>` is one of `float`, `signed` or `unsigned`. The `Tag` identifier
// is accepted by every rule but is not used in the generated code.
macro_rules! impl_number {
    // Internal rule: the bounds shared by every family, taken from the
    // primitive type's own `MIN` / `MAX` constants.
    (@ base_impl $ty:ty, $tag:ident) => {
        #[inline(always)]
        fn min_() -> Self {
            <$ty>::MIN
        }

        #[inline(always)]
        fn max_() -> Self {
            <$ty>::MAX
        }
    };
    // Floating-point types: rounding and absolute value forward to the
    // primitive's inherent methods instead of the identity defaults.
    (float $($ty:ty, $tag:ident); *) => {
        $(
            impl Number for $ty {
                impl_number!(@ base_impl $ty, $tag);

                #[inline]
                fn ceil(self) -> Self {
                    <$ty>::ceil(self)
                }

                #[inline]
                fn floor(self) -> Self {
                    <$ty>::floor(self)
                }

                #[inline]
                fn abs(self) -> Self {
                    <$ty>::abs(self)
                }
            }
        )*
    };
    // Signed integer types: `abs` forwards to the primitive's inherent
    // method; `ceil` / `floor` keep the trait's identity defaults.
    (signed $($ty:ty, $tag:ident); *) => {
        $(
            impl Number for $ty {
                impl_number!(@ base_impl $ty, $tag);

                #[inline]
                fn abs(self) -> Self {
                    <$ty>::abs(self)
                }
            }
        )*
    };
    // Unsigned integer types: the value is returned unchanged by `abs`.
    (unsigned $($ty:ty, $tag:ident); *) => {
        $(
            impl Number for $ty {
                impl_number!(@ base_impl $ty, $tag);

                #[inline]
                fn abs(self) -> Self {
                    self
                }
            }
        )*
    };
}
258
259impl_number!(
260    float
261    f32, F32;
262    f64, F64
263);
264
265impl_number!(
266    signed
267    i32, I32;
268    i64, I64
269);
270
271impl_number!(
272    unsigned
273    u64, U64;
274    usize, Usize
275);
276
#[cfg(test)]
mod tests {
    use super::*;

    /// Calls `ceil` through the `Number` trait so that the trait method,
    /// rather than a primitive's inherent method, is what gets exercised.
    fn ceil_via_trait<T: Number>(value: T) -> T {
        Number::ceil(value)
    }

    /// Calls `floor` through the `Number` trait so that the trait method,
    /// rather than a primitive's inherent method, is what gets exercised.
    fn floor_via_trait<T: Number>(value: T) -> T {
        Number::floor(value)
    }

    #[test]
    fn test_ceil() {
        // Floats round up.
        assert_eq!(ceil_via_trait(1.23_f64), 2.0);
        assert_eq!(ceil_via_trait(-1.23_f32), -1.0);
        // Integers are returned unchanged.
        assert_eq!(ceil_via_trait(0_usize), 0);
        assert_eq!(ceil_via_trait(-3_i32), -3);
    }

    #[test]
    fn test_floor() {
        // Floats round down.
        assert_eq!(floor_via_trait(1.23_f64), 1.0);
        assert_eq!(floor_via_trait(-1.23_f32), -2.0);
        // Integers are returned unchanged.
        assert_eq!(floor_via_trait(0_usize), 0);
        assert_eq!(floor_via_trait(-3_i32), -3);
    }
}