1use crate::dtype::DType;
2
3pub trait Scalar: Clone + core::fmt::Debug + 'static {
5 fn dtype() -> DType;
7 fn zero() -> Self;
9 fn one() -> Self;
11 fn byte_size() -> usize;
13 fn into_f32(self) -> f32;
15 fn into_f64(self) -> f64;
17 fn into_i32(self) -> i32;
19 fn reciprocal(self) -> Self;
21 fn neg(self) -> Self;
23 fn relu(self) -> Self;
25 fn sin(self) -> Self;
27 fn cos(self) -> Self;
29 fn ln(self) -> Self;
31 fn exp(self) -> Self;
33 fn tanh(self) -> Self;
35 fn sqrt(self) -> Self;
38 fn add(self, rhs: Self) -> Self;
40 fn sub(self, rhs: Self) -> Self;
42 fn mul(self, rhs: Self) -> Self;
44 fn div(self, rhs: Self) -> Self;
46 fn pow(self, rhs: Self) -> Self;
48 fn cmplt(self, rhs: Self) -> Self;
50 fn max(self, rhs: Self) -> Self;
52 fn max_value() -> Self;
54 fn min_value() -> Self;
56 fn epsilon() -> Self;
58 fn is_equal(self, rhs: Self) -> bool;
61}
62
63impl Scalar for f32 {
64 fn dtype() -> DType {
65 DType::F32
66 }
67
68 fn zero() -> Self {
69 0.
70 }
71
72 fn one() -> Self {
73 1.
74 }
75
76 fn byte_size() -> usize {
77 4
78 }
79
80 fn into_f32(self) -> f32 {
81 self
82 }
83
84 fn into_f64(self) -> f64 {
85 self as f64
86 }
87
88 fn into_i32(self) -> i32 {
89 self as i32
90 }
91
92 fn reciprocal(self) -> Self {
93 1.0 / self
94 }
95
96 fn neg(self) -> Self {
97 -self
98 }
99
100 fn relu(self) -> Self {
101 self.max(0.)
102 }
103
104 fn sin(self) -> Self {
105 f32::sin(self)
106 }
107
108 fn cos(self) -> Self {
109 f32::cos(self)
110 }
111
112 fn exp(self) -> Self {
113 f32::exp(self)
114 }
115
116 fn ln(self) -> Self {
117 f32::ln(self)
118 }
119
120 fn tanh(self) -> Self {
121 f32::tanh(self)
122 }
123
124 fn sqrt(self) -> Self {
125 if self >= 0. {
127 Self::from_bits((self.to_bits() + 0x3f80_0000) >> 1)
128 } else {
129 Self::NAN
130 }
131 }
132
133 fn add(self, rhs: Self) -> Self {
134 self + rhs
135 }
136
137 fn sub(self, rhs: Self) -> Self {
138 self - rhs
139 }
140
141 fn mul(self, rhs: Self) -> Self {
142 self * rhs
143 }
144
145 fn div(self, rhs: Self) -> Self {
146 self / rhs
147 }
148
149 fn pow(self, rhs: Self) -> Self {
150 f32::powf(self, rhs)
151 }
152
153 fn cmplt(self, rhs: Self) -> Self {
154 (self < rhs) as i32 as f32
155 }
156
157 fn max(self, rhs: Self) -> Self {
158 f32::max(self, rhs)
159 }
160
161 fn max_value() -> Self {
162 f32::MAX
163 }
164
165 fn min_value() -> Self {
166 f32::MIN
167 }
168
169 fn epsilon() -> Self {
170 0.00001
171 }
172
173 fn is_equal(self, rhs: Self) -> bool {
174 (self == -f32::INFINITY && rhs == -f32::INFINITY)
176 || (self - rhs).abs() < Self::epsilon()
177 || (self - rhs).abs() < self.abs() * 0.01
178 }
179}
180
181impl Scalar for f64 {
182 fn dtype() -> DType {
183 DType::F64
184 }
185
186 fn zero() -> Self {
187 0.
188 }
189
190 fn one() -> Self {
191 1.
192 }
193
194 fn byte_size() -> usize {
195 8
196 }
197
198 fn into_f32(self) -> f32 {
199 self as f32
200 }
201
202 fn into_f64(self) -> f64 {
203 self
204 }
205
206 fn into_i32(self) -> i32 {
207 self as i32
208 }
209
210 fn reciprocal(self) -> Self {
211 1.0 / self
212 }
213
214 fn neg(self) -> Self {
215 -self
216 }
217
218 fn relu(self) -> Self {
219 self.max(0.)
220 }
221
222 fn sin(self) -> Self {
223 f64::sin(self)
224 }
225
226 fn cos(self) -> Self {
227 f64::cos(self)
228 }
229
230 fn exp(self) -> Self {
231 f64::exp(self)
232 }
233
234 fn ln(self) -> Self {
235 f64::ln(self)
236 }
237
238 fn tanh(self) -> Self {
239 f64::tanh(self)
240 }
241
242 fn sqrt(self) -> Self {
243 f64::sqrt(self)
244 }
245
246 fn add(self, rhs: Self) -> Self {
247 self + rhs
248 }
249
250 fn sub(self, rhs: Self) -> Self {
251 self - rhs
252 }
253
254 fn mul(self, rhs: Self) -> Self {
255 self * rhs
256 }
257
258 fn div(self, rhs: Self) -> Self {
259 self / rhs
260 }
261
262 fn pow(self, rhs: Self) -> Self {
263 f64::powf(self, rhs)
264 }
265
266 fn cmplt(self, rhs: Self) -> Self {
267 (self < rhs) as i32 as f64
268 }
269
270 fn max(self, rhs: Self) -> Self {
271 f64::max(self, rhs)
272 }
273
274 fn max_value() -> Self {
275 f64::MAX
276 }
277
278 fn min_value() -> Self {
279 f64::MIN
280 }
281
282 fn epsilon() -> Self {
283 0.00001
284 }
285
286 fn is_equal(self, rhs: Self) -> bool {
287 (self == -f64::INFINITY && rhs == -f64::INFINITY)
289 || (self - rhs).abs() < Self::epsilon()
290 || (self - rhs).abs() < self.abs() * 0.01
291 }
292}
293
294impl Scalar for i32 {
295 fn dtype() -> DType {
296 DType::I32
297 }
298
299 fn zero() -> Self {
300 0
301 }
302
303 fn one() -> Self {
304 1
305 }
306
307 fn byte_size() -> usize {
308 4
309 }
310
311 fn into_f32(self) -> f32 {
312 self as f32
313 }
314
315 fn into_f64(self) -> f64 {
316 self as f64
317 }
318
319 fn into_i32(self) -> i32 {
320 self
321 }
322
323 fn reciprocal(self) -> Self {
324 1 / self
325 }
326
327 fn neg(self) -> Self {
328 -self
329 }
330
331 fn relu(self) -> Self {
332 <i32 as Ord>::max(self, 0)
333 }
334
335 fn sin(self) -> Self {
336 f32::sin(self as f32) as i32
337 }
338
339 fn cos(self) -> Self {
340 f32::cos(self as f32) as i32
341 }
342
343 fn exp(self) -> Self {
344 f32::exp(self as f32) as i32
345 }
346
347 fn ln(self) -> Self {
348 f32::ln(self as f32) as i32
349 }
350
351 fn tanh(self) -> Self {
352 f32::tanh(self as f32) as i32
353 }
354
355 fn sqrt(self) -> Self {
356 (self as f32).sqrt() as i32
357 }
358
359 fn add(self, rhs: Self) -> Self {
360 self + rhs
361 }
362
363 fn sub(self, rhs: Self) -> Self {
364 self - rhs
365 }
366
367 fn mul(self, rhs: Self) -> Self {
368 self * rhs
369 }
370
371 fn div(self, rhs: Self) -> Self {
372 self / rhs
373 }
374
375 fn pow(self, rhs: Self) -> Self {
376 i32::pow(self, rhs as u32)
377 }
378
379 fn cmplt(self, rhs: Self) -> Self {
380 (self < rhs) as i32
381 }
382
383 fn max(self, rhs: Self) -> Self {
384 <i32 as Ord>::max(self, rhs)
385 }
386
387 fn max_value() -> Self {
388 i32::MAX
389 }
390
391 fn min_value() -> Self {
392 i32::MIN
393 }
394
395 fn epsilon() -> Self {
396 0
397 }
398
399 fn is_equal(self, rhs: Self) -> bool {
400 self == rhs
401 }
402}