zyx_core/
scalar.rs

1use crate::dtype::DType;
2
3/// Scalar trait is implemented for all [dtypes](DType)
4pub trait Scalar: Clone + core::fmt::Debug + 'static {
5    /// Get dtype of Self
6    fn dtype() -> DType;
7    /// Get zero of Self
8    fn zero() -> Self;
9    /// Get one of Self
10    fn one() -> Self;
11    /// Bute size of Self
12    fn byte_size() -> usize;
13    /// Convert self into f32
14    fn into_f32(self) -> f32;
15    /// Convert self into f64
16    fn into_f64(self) -> f64;
17    /// Convert self into i32
18    fn into_i32(self) -> i32;
19    /// 1/self
20    fn reciprocal(self) -> Self;
21    /// Neg
22    fn neg(self) -> Self;
23    /// ReLU
24    fn relu(self) -> Self;
25    /// Sin
26    fn sin(self) -> Self;
27    /// Cos
28    fn cos(self) -> Self;
29    /// Ln
30    fn ln(self) -> Self;
31    /// Exp
32    fn exp(self) -> Self;
33    /// Tanh
34    fn tanh(self) -> Self;
35    /// Square root of this scalar.
36    /// That this function may be imprecise.
37    fn sqrt(self) -> Self;
38    /// Add
39    fn add(self, rhs: Self) -> Self;
40    /// Sub
41    fn sub(self, rhs: Self) -> Self;
42    /// Mul
43    fn mul(self, rhs: Self) -> Self;
44    /// Div
45    fn div(self, rhs: Self) -> Self;
46    /// Pow
47    fn pow(self, rhs: Self) -> Self;
48    /// Compare less than
49    fn cmplt(self, rhs: Self) -> Self;
50    /// Max of two numbers
51    fn max(self, rhs: Self) -> Self;
52    /// Max value of this dtype
53    fn max_value() -> Self;
54    /// Min value of this dtype
55    fn min_value() -> Self;
56    /// Very small value of scalar, very close to zero
57    fn epsilon() -> Self;
58    /// Comparison for scalars,
59    /// if they are floats, this checks for diffs > Self::epsilon()
60    fn is_equal(self, rhs: Self) -> bool;
61}
62
63impl Scalar for f32 {
64    fn dtype() -> DType {
65        DType::F32
66    }
67
68    fn zero() -> Self {
69        0.
70    }
71
72    fn one() -> Self {
73        1.
74    }
75
76    fn byte_size() -> usize {
77        4
78    }
79
80    fn into_f32(self) -> f32 {
81        self
82    }
83
84    fn into_f64(self) -> f64 {
85        self as f64
86    }
87
88    fn into_i32(self) -> i32 {
89        self as i32
90    }
91
92    fn reciprocal(self) -> Self {
93        1.0 / self
94    }
95
96    fn neg(self) -> Self {
97        -self
98    }
99
100    fn relu(self) -> Self {
101        self.max(0.)
102    }
103
104    fn sin(self) -> Self {
105        f32::sin(self)
106    }
107
108    fn cos(self) -> Self {
109        f32::cos(self)
110    }
111
112    fn exp(self) -> Self {
113        f32::exp(self)
114    }
115
116    fn ln(self) -> Self {
117        f32::ln(self)
118    }
119
120    fn tanh(self) -> Self {
121        f32::tanh(self)
122    }
123
124    fn sqrt(self) -> Self {
125        // good enough (error of ~ 5%)
126        if self >= 0. {
127            Self::from_bits((self.to_bits() + 0x3f80_0000) >> 1)
128        } else {
129            Self::NAN
130        }
131    }
132
133    fn add(self, rhs: Self) -> Self {
134        self + rhs
135    }
136
137    fn sub(self, rhs: Self) -> Self {
138        self - rhs
139    }
140
141    fn mul(self, rhs: Self) -> Self {
142        self * rhs
143    }
144
145    fn div(self, rhs: Self) -> Self {
146        self / rhs
147    }
148
149    fn pow(self, rhs: Self) -> Self {
150        f32::powf(self, rhs)
151    }
152
153    fn cmplt(self, rhs: Self) -> Self {
154        (self < rhs) as i32 as f32
155    }
156
157    fn max(self, rhs: Self) -> Self {
158        f32::max(self, rhs)
159    }
160
161    fn max_value() -> Self {
162        f32::MAX
163    }
164
165    fn min_value() -> Self {
166        f32::MIN
167    }
168
169    fn epsilon() -> Self {
170        0.00001
171    }
172
173    fn is_equal(self, rhs: Self) -> bool {
174        // Less than 1% error is OK
175        (self == -f32::INFINITY && rhs == -f32::INFINITY)
176            || (self - rhs).abs() < Self::epsilon()
177            || (self - rhs).abs() < self.abs() * 0.01
178    }
179}
180
181impl Scalar for f64 {
182    fn dtype() -> DType {
183        DType::F64
184    }
185
186    fn zero() -> Self {
187        0.
188    }
189
190    fn one() -> Self {
191        1.
192    }
193
194    fn byte_size() -> usize {
195        8
196    }
197
198    fn into_f32(self) -> f32 {
199        self as f32
200    }
201
202    fn into_f64(self) -> f64 {
203        self
204    }
205
206    fn into_i32(self) -> i32 {
207        self as i32
208    }
209
210    fn reciprocal(self) -> Self {
211        1.0 / self
212    }
213
214    fn neg(self) -> Self {
215        -self
216    }
217
218    fn relu(self) -> Self {
219        self.max(0.)
220    }
221
222    fn sin(self) -> Self {
223        f64::sin(self)
224    }
225
226    fn cos(self) -> Self {
227        f64::cos(self)
228    }
229
230    fn exp(self) -> Self {
231        f64::exp(self)
232    }
233
234    fn ln(self) -> Self {
235        f64::ln(self)
236    }
237
238    fn tanh(self) -> Self {
239        f64::tanh(self)
240    }
241
242    fn sqrt(self) -> Self {
243        f64::sqrt(self)
244    }
245
246    fn add(self, rhs: Self) -> Self {
247        self + rhs
248    }
249
250    fn sub(self, rhs: Self) -> Self {
251        self - rhs
252    }
253
254    fn mul(self, rhs: Self) -> Self {
255        self * rhs
256    }
257
258    fn div(self, rhs: Self) -> Self {
259        self / rhs
260    }
261
262    fn pow(self, rhs: Self) -> Self {
263        f64::powf(self, rhs)
264    }
265
266    fn cmplt(self, rhs: Self) -> Self {
267        (self < rhs) as i32 as f64
268    }
269
270    fn max(self, rhs: Self) -> Self {
271        f64::max(self, rhs)
272    }
273
274    fn max_value() -> Self {
275        f64::MAX
276    }
277
278    fn min_value() -> Self {
279        f64::MIN
280    }
281
282    fn epsilon() -> Self {
283        0.00001
284    }
285
286    fn is_equal(self, rhs: Self) -> bool {
287        // Less than 1% error is OK
288        (self == -f64::INFINITY && rhs == -f64::INFINITY)
289            || (self - rhs).abs() < Self::epsilon()
290            || (self - rhs).abs() < self.abs() * 0.01
291    }
292}
293
294impl Scalar for i32 {
295    fn dtype() -> DType {
296        DType::I32
297    }
298
299    fn zero() -> Self {
300        0
301    }
302
303    fn one() -> Self {
304        1
305    }
306
307    fn byte_size() -> usize {
308        4
309    }
310
311    fn into_f32(self) -> f32 {
312        self as f32
313    }
314
315    fn into_f64(self) -> f64 {
316        self as f64
317    }
318
319    fn into_i32(self) -> i32 {
320        self
321    }
322
323    fn reciprocal(self) -> Self {
324        1 / self
325    }
326
327    fn neg(self) -> Self {
328        -self
329    }
330
331    fn relu(self) -> Self {
332        <i32 as Ord>::max(self, 0)
333    }
334
335    fn sin(self) -> Self {
336        f32::sin(self as f32) as i32
337    }
338
339    fn cos(self) -> Self {
340        f32::cos(self as f32) as i32
341    }
342
343    fn exp(self) -> Self {
344        f32::exp(self as f32) as i32
345    }
346
347    fn ln(self) -> Self {
348        f32::ln(self as f32) as i32
349    }
350
351    fn tanh(self) -> Self {
352        f32::tanh(self as f32) as i32
353    }
354
355    fn sqrt(self) -> Self {
356        (self as f32).sqrt() as i32
357    }
358
359    fn add(self, rhs: Self) -> Self {
360        self + rhs
361    }
362
363    fn sub(self, rhs: Self) -> Self {
364        self - rhs
365    }
366
367    fn mul(self, rhs: Self) -> Self {
368        self * rhs
369    }
370
371    fn div(self, rhs: Self) -> Self {
372        self / rhs
373    }
374
375    fn pow(self, rhs: Self) -> Self {
376        i32::pow(self, rhs as u32)
377    }
378
379    fn cmplt(self, rhs: Self) -> Self {
380        (self < rhs) as i32
381    }
382
383    fn max(self, rhs: Self) -> Self {
384        <i32 as Ord>::max(self, rhs)
385    }
386
387    fn max_value() -> Self {
388        i32::MAX
389    }
390
391    fn min_value() -> Self {
392        i32::MIN
393    }
394
395    fn epsilon() -> Self {
396        0
397    }
398
399    fn is_equal(self, rhs: Self) -> bool {
400        self == rhs
401    }
402}