tea_dtype/number.rs
1use std::cmp::PartialOrd;
2use std::ops::{Add, AddAssign, DivAssign, MulAssign, Sub, SubAssign};
3
4use num_traits::{MulAdd, Num};
5
6use super::cast::Cast;
7use super::isnone::IsNone;
8
/// Performs a single step of Kahan (compensated) summation.
///
/// `sum` is the running total, `v` the value to add, and `c` the running
/// compensation term: the low-order part lost by earlier additions. `c` is
/// updated in place and the new total is returned. Start `c` at zero for a
/// fresh sum.
///
/// See <https://en.wikipedia.org/wiki/Kahan_summation_algorithm>.
#[inline]
fn kh_sum<T>(sum: T, v: T, c: &mut T) -> T
where
    T: Add<Output = T> + Sub<Output = T> + Copy,
{
    // Re-inject whatever was lost on the previous step before adding.
    let corrected = v - *c;
    let total = sum + corrected;
    // `(total - sum)` is what actually got added; its difference from
    // `corrected` is the rounding error to carry into the next step.
    let lost = (total - sum) - corrected;
    *c = lost;
    total
}
20/// A trait representing numeric types with various operations and conversions.
21///
22/// This trait combines several other traits and provides additional functionality
23/// for numeric types. It includes operations for arithmetic, comparison, conversion,
24/// and special numeric functions.
25///
26/// # Type Constraints
27///
28/// The type implementing this trait must satisfy the following constraints:
29/// - `Copy`: The type can be copied bit-for-bit.
30/// - `Send`: The type can be safely transferred across thread boundaries.
31/// - `Sync`: The type can be safely shared between threads.
32/// - `IsNone`: The type has a concept of a "none" value.
33/// - `Sized`: The type has a known size at compile-time.
34/// - `Default`: The type has a default value.
35/// - `Num`: The type supports basic numeric operations.
36/// - `AddAssign`, `SubAssign`, `MulAssign`, `DivAssign`: The type supports compound assignment operations.
37/// - `PartialOrd`: The type can be partially ordered.
38/// - `MulAdd`: The type supports fused multiply-add operations.
39/// - `Cast<f64>`, `Cast<f32>`, `Cast<usize>`, `Cast<i32>`, `Cast<i64>`: The type can be cast to these numeric types.
40/// - `'static`: The type has a static lifetime.
41pub trait Number:
42 Copy
43 // + Clone
44 + Send
45 + Sync
46 + IsNone
47 + Sized
48 + Default
49 + Num
50 + AddAssign
51 + SubAssign
52 + MulAssign
53 + DivAssign
54 + PartialOrd
55 + MulAdd
56 + Cast<f64>
57 + Cast<f32>
58 + Cast<usize>
59 + Cast<i32>
60 + Cast<i64>
61 + 'static
62{
63 // type Dtype;
64 /// Returns the minimum value of the data type.
65 fn min_() -> Self;
66
67 /// Returns the maximum value of the data type.
68 fn max_() -> Self;
69
70 /// Computes the absolute value of the number.
71 fn abs(self) -> Self;
72
73 /// Computes the ceiling of the number.
74 ///
75 /// For integer types, this is typically the identity function.
76 #[inline(always)]
77 fn ceil(self) -> Self {
78 self
79 }
80
81 /// Computes the floor of the number.
82 ///
83 /// For integer types, this is typically the identity function.
84 #[inline(always)]
85 fn floor(self) -> Self {
86 self
87 }
88
89 /// Returns the minimum of self and other.
90 #[inline]
91 fn min_with(self, other: Self) -> Self {
92 if other < self {
93 other
94 } else {
95 self
96 }
97 }
98
99 /// Returns the maximum of self and other.
100 #[inline]
101 fn max_with(self, other: Self) -> Self {
102 if other > self {
103 other
104 } else {
105 self
106 }
107 }
108
109 /// Casts the number to f32.
110 #[inline(always)]
111 fn f32(self) -> f32 {
112 Cast::<f32>::cast(self)
113 }
114
115 /// Casts the number to f64.
116 #[inline(always)]
117 fn f64(self) -> f64 {
118 Cast::<f64>::cast(self)
119 }
120
121 /// Casts the number to i32.
122 #[inline(always)]
123 fn i32(self) -> i32 {
124 Cast::<i32>::cast(self)
125 }
126
127 /// Casts the number to i64.
128 #[inline(always)]
129 fn i64(self) -> i64 {
130 Cast::<i64>::cast(self)
131 }
132
133 /// Casts the number to usize.
134 #[inline(always)]
135 fn usize(self) -> usize {
136 Cast::<usize>::cast(self)
137 }
138
139 /// Creates a value of type Self using a value of type U using `Cast`.
140 #[inline(always)]
141 fn fromas<U>(v: U) -> Self
142 where
143 U: Number + Cast<Self>,
144 Self: 'static,
145 {
146 v.to::<Self>()
147 }
148
149 /// Casts self to another type T using `Cast`.
150 #[inline(always)]
151 fn to<T: Number>(self) -> T
152 where
153 Self: Cast<T>,
154 {
155 Cast::<T>::cast(self)
156 }
157
158 /// Performs Kahan summation.
159 ///
160 /// This method implements the Kahan summation algorithm, which helps reduce
161 /// numerical error in the sum of a sequence of floating point numbers.
162 #[inline(always)]
163 #[must_use]
164 fn kh_sum(self, v: Self, c: &mut Self) -> Self {
165 kh_sum(self, v, c)
166 }
167
168 /// Conditionally adds `other` to `self` and increments `n`.
169 ///
170 /// If `other` is not none, it adds `other` to `self` and increments `n`.
171 /// Otherwise, it returns `self` unchanged.
172 #[inline]
173 fn n_add(self, other: Self, n: &mut usize) -> Self {
174 // note: only check if other is NaN
175 // assume that self is not NaN
176 if other.not_none() {
177 *n += 1;
178 self + other
179 } else {
180 self
181 }
182 }
183
184 /// Conditionally multiplies `self` by `other` and increments `n`.
185 ///
186 /// If `other` is not none, it multiplies `self` by `other` and increments `n`.
187 /// Otherwise, it returns `self` unchanged.
188 #[inline]
189 fn n_prod(self, other: Self, n: &mut usize) -> Self {
190 // note: only check if other is NaN
191 // assume that self is not NaN
192 if other.not_none() {
193 *n += 1;
194 self * other
195 } else {
196 self
197 }
198 }
199}
200
// Generates `impl Number for T` for every listed type.
//
// Usage: `impl_number!(<kind> T1, TAG1; T2, TAG2; ...)`, where `<kind>` is one of
// `float`, `signed` or `unsigned`. The `TAG` identifier paired with each type is
// accepted for call-site compatibility but is not used by any arm.
macro_rules! impl_number {
    // Internal arm: the `min_` / `max_` accessors shared by every kind, backed by
    // the type's own `MIN` / `MAX` constants.
    (@ base_impl $dtype:ty, $datatype:ident) => {
        #[inline(always)]
        fn min_() -> $dtype { <$dtype>::MIN }

        #[inline(always)]
        fn max_() -> $dtype { <$dtype>::MAX }
    };

    // Floats: forward `abs`, `ceil` and `floor`. Inside each body the call
    // resolves to the inherent float method (inherent methods take priority over
    // trait methods), so these do not recurse.
    (float $($dtype:ty, $datatype:ident);*) => {
        $(
            impl Number for $dtype {
                impl_number!(@ base_impl $dtype, $datatype);

                #[inline]
                fn abs(self) -> Self { self.abs() }

                #[inline]
                fn ceil(self) -> Self { self.ceil() }

                #[inline]
                fn floor(self) -> Self { self.floor() }
            }
        )*
    };

    // Signed integers: forward `abs` to the inherent method (note that
    // `MIN.abs()` overflows, as documented for the std integer `abs`).
    // `ceil` / `floor` keep the trait's identity defaults.
    (signed $($dtype:ty, $datatype:ident);*) => {
        $(
            impl Number for $dtype {
                impl_number!(@ base_impl $dtype, $datatype);

                #[inline]
                fn abs(self) -> Self { self.abs() }
            }
        )*
    };

    // Unsigned integers: every value is already non-negative, so `abs` is the
    // identity. `ceil` / `floor` keep the trait's identity defaults.
    (unsigned $($dtype:ty, $datatype:ident);*) => {
        $(
            impl Number for $dtype {
                impl_number!(@ base_impl $dtype, $datatype);

                #[inline]
                fn abs(self) -> Self { self }
            }
        )*
    };
}
258
259impl_number!(
260 float
261 f32, F32;
262 f64, F64
263);
264
265impl_number!(
266 signed
267 i32, I32;
268 i64, I64
269);
270
271impl_number!(
272 unsigned
273 u64, U64;
274 usize, Usize
275);
276
#[cfg(test)]
mod tests {
    use super::*;

    // These helpers force dispatch through the `Number` trait, rather than
    // through an inherent method of the same name on the concrete type.
    fn ceil_of<T: Number>(v: T) -> T {
        v.ceil()
    }

    fn floor_of<T: Number>(v: T) -> T {
        v.floor()
    }

    fn abs_of<T: Number>(v: T) -> T {
        v.abs()
    }

    #[test]
    fn test_ceil() {
        assert_eq!(ceil_of(1.23_f64), 2.);
        assert_eq!(ceil_of(-1.23_f32), -1.);
        assert_eq!(ceil_of(0_usize), 0);
        assert_eq!(ceil_of(-3i32), -3);
    }

    #[test]
    fn test_floor() {
        assert_eq!(floor_of(1.23_f64), 1.);
        assert_eq!(floor_of(-1.23_f32), -2.);
        assert_eq!(floor_of(0_usize), 0);
        assert_eq!(floor_of(-3i32), -3);
    }

    #[test]
    fn test_abs() {
        assert_eq!(abs_of(-2.5_f64), 2.5);
        assert_eq!(abs_of(-2.5_f32), 2.5);
        assert_eq!(abs_of(-7_i32), 7);
        assert_eq!(abs_of(-7_i64), 7);
        assert_eq!(abs_of(5_u64), 5);
        assert_eq!(abs_of(5_usize), 5);
    }

    #[test]
    fn test_min_max() {
        assert_eq!(<i32 as Number>::min_(), i32::MIN);
        assert_eq!(<f64 as Number>::max_(), f64::MAX);
        assert_eq!(<usize as Number>::min_(), 0);

        assert_eq!(Number::min_with(3_i64, 5), 3);
        assert_eq!(Number::max_with(3_i64, 5), 5);
        // NaN is unordered: `other` only wins when the comparison holds, so a
        // NaN `other` leaves `self` in place, while a NaN `self` is kept.
        assert_eq!(Number::min_with(1.0_f64, f64::NAN), 1.0);
        assert_eq!(Number::max_with(1.0_f64, f64::NAN), 1.0);
        assert!(Number::min_with(f64::NAN, 1.0_f64).is_nan());
    }

    #[test]
    fn test_kh_sum() {
        // 1.0 + EPSILON/2 ties back to 1.0. Kahan keeps the lost half in `c`
        // and recovers it on the next step, which naive summation would not.
        let half = f64::EPSILON / 2.0;
        let mut c = 0.0_f64;
        let s = Number::kh_sum(1.0_f64, half, &mut c);
        assert_eq!(s, 1.0);
        assert_eq!(c, -half);
        let s = Number::kh_sum(s, half, &mut c);
        assert_eq!(s, 1.0 + f64::EPSILON);
        assert_eq!(c, 0.0);
    }
}