Skip to main content

zyx/
scalar.rs

1// Copyright (C) 2025 zk4x
2// SPDX-License-Identifier: LGPL-3.0-only
3
4//! Trait describing required operations on scalar values
5
6use crate::dtype::DType;
7use half::{bf16, f16};
8
9/// Scalar trait is implemented for all [dtypes](DType)
10pub trait Scalar: Copy + Clone + Sized + core::fmt::Debug + 'static + PartialEq + Send + Sync + PartialOrd {
11    /// From bf16
12    #[must_use]
13    fn from_bf16(t: bf16) -> Self;
14    /// From f16
15    #[must_use]
16    fn from_f16(t: f16) -> Self;
17    /// From f32
18    #[must_use]
19    fn from_f32(t: f32) -> Self;
20    /// From f64
21    #[must_use]
22    fn from_f64(t: f64) -> Self;
23    /// From u8
24    #[must_use]
25    fn from_u8(t: u8) -> Self;
26    /// From u16
27    #[must_use]
28    fn from_u16(t: u16) -> Self;
29    /// From u32
30    #[must_use]
31    fn from_u32(t: u32) -> Self;
32    /// From u64
33    #[must_use]
34    fn from_u64(t: u64) -> Self;
35    /// From i8
36    #[must_use]
37    fn from_i8(t: i8) -> Self;
38    /// From i16
39    fn from_i16(t: i16) -> Self;
40    #[must_use]
41    /// From i32
42    fn from_i32(t: i32) -> Self;
43    /// From i64
44    #[must_use]
45    fn from_i64(t: i64) -> Self;
46    /// From bool
47    #[must_use]
48    fn from_bool(t: bool) -> Self;
49    /// From little endian bytes
50    #[must_use]
51    fn from_le_bytes(bytes: &[u8]) -> Self;
52    /// To native endian bytes
53    #[must_use]
54    #[allow(clippy::ptr_as_ptr)]
55    fn to_ne_bytes(&self) -> &[u8];
56    /// Get size in bits
57    #[must_use]
58    fn bit_size() -> u8 {
59        Self::dtype().bit_size()
60    }
61    /// Get dtype of Self
62    #[must_use]
63    fn dtype() -> DType;
64    /// Get zero of Self
65    #[must_use]
66    fn zero() -> Self;
67    /// Get one of Self
68    #[must_use]
69    fn one() -> Self;
70    /// Absolute value of self
71    #[must_use]
72    fn abs(self) -> Self;
73    /// Neg
74    #[must_use]
75    fn neg(self) -> Self;
76    /// Exp 2
77    #[must_use]
78    fn exp2(self) -> Self;
79    /// Log 2
80    #[must_use]
81    fn log2(self) -> Self;
82    /// `ReLU`
83    #[must_use]
84    fn relu(self) -> Self;
85    /// Not
86    #[must_use]
87    fn not(self) -> Self;
88    /// Nonzero
89    #[must_use]
90    fn nonzero(self) -> Self;
91    /// Add
92    #[must_use]
93    fn add(self, rhs: Self) -> Self;
94    /// Sub
95    #[must_use]
96    fn sub(self, rhs: Self) -> Self;
97    /// Mul
98    #[must_use]
99    fn mul(self, rhs: Self) -> Self;
100    /// Div
101    #[must_use]
102    fn div(self, rhs: Self) -> Self;
103    /// Pow
104    #[must_use]
105    fn pow(self, rhs: Self) -> Self;
106    /// Mod
107    #[must_use]
108    fn mod_(self, rhs: Self) -> Self;
109    /// Compare less than
110    #[must_use]
111    fn cmplt(self, rhs: Self) -> bool;
112    /// Compare less than
113    #[must_use]
114    fn cmpgt(self, rhs: Self) -> bool;
115    /// Noteq
116    #[must_use]
117    fn noteq(self, rhs: Self) -> bool;
118    /// Compare less than
119    #[must_use]
120    fn or(self, rhs: Self) -> bool;
121    /// Bitxor
122    #[must_use]
123    fn bitxor(self, rhs: Self) -> Self;
124    /// Bitor
125    #[must_use]
126    fn bitor(self, rhs: Self) -> Self;
127    /// Bitand
128    #[must_use]
129    fn bitand(self, rhs: Self) -> Self;
130    /// Bit shift left
131    #[must_use]
132    fn bitshiftleft(self, rhs: Self) -> Self;
133    /// Bit shift rigt
134    #[must_use]
135    fn bitshiftright(self, rhs: Self) -> Self;
136    /// And
137    #[must_use]
138    fn and(self, rhs: Self) -> bool;
139    /// Max of two numbers
140    #[must_use]
141    fn max(self, rhs: Self) -> Self;
142    /// Max value of this dtype
143    #[must_use]
144    fn max_value() -> Self;
145    /// Min value of this dtype
146    #[must_use]
147    fn min_value() -> Self;
148    /// Comparison for scalars,
149    /// if they are floats, this checks for diffs > `Self::epsilon()`
150    #[must_use]
151    fn is_equal(self, rhs: Self) -> bool;
152    /// Cast into different dtype
153    #[must_use]
154    fn cast<T: Scalar>(self) -> T {
155        use core::mem::transmute_copy as t;
156        unsafe {
157            match Self::dtype() {
158                DType::BF16 => T::from_bf16(t(&self)),
159                DType::F16 => T::from_f16(t(&self)),
160                DType::F32 => T::from_f32(t(&self)),
161                DType::F64 => T::from_f64(t(&self)),
162                DType::U8 => T::from_u8(t(&self)),
163                DType::U16 => T::from_u16(t(&self)),
164                DType::U32 => T::from_u32(t(&self)),
165                DType::U64 => T::from_u64(t(&self)),
166                DType::I8 => T::from_i8(t(&self)),
167                DType::I16 => T::from_i16(t(&self)),
168                DType::I32 => T::from_i32(t(&self)),
169                DType::I64 => T::from_i64(t(&self)),
170                DType::Bool => T::from_bool(t(&self)),
171            }
172        }
173    }
174    /// Very small value of scalar, very close to zero, zero in case of integers
175    #[must_use]
176    fn epsilon() -> Self {
177        Self::zero()
178    }
179}
180
181/// Float dtype
182pub trait Float: Scalar {
183    /// Round down
184    #[must_use]
185    fn floor(self) -> Self;
186    /// 1/self
187    #[must_use]
188    fn reciprocal(self) -> Self;
189    /// Sin
190    #[must_use]
191    fn sin(self) -> Self;
192    /// Cos
193    #[must_use]
194    fn cos(self) -> Self;
195    /// Square root of this scalar.
196    #[must_use]
197    fn sqrt(self) -> Self;
198    /// Truncate towards zero
199    #[must_use]
200    fn trunc(self) -> Self;
201}
202
203impl Scalar for bf16 {
204    fn from_bf16(t: bf16) -> Self {
205        t
206    }
207
208    fn from_f16(t: f16) -> Self {
209        bf16::from_f32(t.into())
210    }
211
212    fn from_f32(t: f32) -> Self {
213        bf16::from_f32(t)
214    }
215
216    fn from_f64(t: f64) -> Self {
217        bf16::from_f64(t)
218    }
219
220    fn from_u8(t: u8) -> Self {
221        bf16::from_f32(f32::from(t))
222    }
223
224    fn from_u16(t: u16) -> Self {
225        bf16::from_f32(f32::from(t))
226    }
227
228    fn from_u32(t: u32) -> Self {
229        bf16::from_f64(f64::from(t))
230    }
231
232    #[allow(clippy::cast_precision_loss)]
233    fn from_u64(t: u64) -> Self {
234        bf16::from_f64(t as f64)
235    }
236
237    fn from_i8(t: i8) -> Self {
238        bf16::from_f32(f32::from(t))
239    }
240
241    fn from_i16(t: i16) -> Self {
242        bf16::from_f32(f32::from(t))
243    }
244
245    #[allow(clippy::cast_possible_truncation)]
246    fn from_i32(t: i32) -> Self {
247        bf16::from_f32(t as f32)
248    }
249
250    #[allow(clippy::cast_possible_truncation)]
251    fn from_i64(t: i64) -> Self {
252        bf16::from_f32(t as f32)
253    }
254
255    fn from_bool(t: bool) -> Self {
256        bf16::from_f32(f32::from(t))
257    }
258
259    fn from_le_bytes(bytes: &[u8]) -> Self {
260        bf16::from_le_bytes([bytes[0], bytes[1]])
261    }
262
263    fn to_ne_bytes(&self) -> &[u8] {
264        let x: *const Self = self;
265        unsafe { std::slice::from_raw_parts(x.cast(), 2) }
266    }
267
268    fn dtype() -> DType {
269        DType::BF16
270    }
271
272    fn zero() -> Self {
273        bf16::ZERO
274    }
275
276    fn one() -> Self {
277        bf16::ONE
278    }
279
280    fn abs(self) -> Self {
281        self.max(-self)
282    }
283
284    fn neg(self) -> Self {
285        -self
286    }
287
288    fn exp2(self) -> Self {
289        bf16::from_f64(f64::from(self).exp2())
290    }
291
292    fn log2(self) -> Self {
293        bf16::from_f64(f64::from(self).log2())
294    }
295
296    fn relu(self) -> Self {
297        Scalar::max(self, Self::ZERO)
298    }
299
300    fn not(self) -> Self {
301        bf16::from_f32(if f64::from(self) == 0.0 { 0.0 } else { 1.0 })
302    }
303
304    fn nonzero(self) -> Self {
305        bf16::from_f32(if f64::from(self) == 0.0 { 0.0 } else { 1.0 })
306    }
307
308    fn add(self, rhs: Self) -> Self {
309        self + rhs
310    }
311
312    fn sub(self, rhs: Self) -> Self {
313        self - rhs
314    }
315
316    fn mul(self, rhs: Self) -> Self {
317        self * rhs
318    }
319
320    fn div(self, rhs: Self) -> Self {
321        self / rhs
322    }
323
324    fn pow(self, rhs: Self) -> Self {
325        bf16::from_f64(f64::from(self).powf(f64::from(rhs)))
326    }
327
328    fn mod_(self, rhs: Self) -> Self {
329        self % rhs
330    }
331
332    fn cmplt(self, rhs: Self) -> bool {
333        self < rhs
334    }
335
336    fn cmpgt(self, rhs: Self) -> bool {
337        self > rhs
338    }
339
340    fn noteq(self, rhs: Self) -> bool {
341        self != rhs
342    }
343
344    fn or(self, rhs: Self) -> bool {
345        self != Self::ZERO || rhs != Self::ZERO
346    }
347
348    fn bitxor(self, rhs: Self) -> Self {
349        let a = f64::from(self) as i64;
350        let b = f64::from(rhs) as i64;
351        bf16::from_f32((a ^ b) as f32)
352    }
353
354    fn bitor(self, rhs: Self) -> Self {
355        let a = f64::from(self) as i64;
356        let b = f64::from(rhs) as i64;
357        bf16::from_f32((a | b) as f32)
358    }
359
360    fn bitand(self, rhs: Self) -> Self {
361        let a = f64::from(self) as i64;
362        let b = f64::from(rhs) as i64;
363        bf16::from_f32((a & b) as f32)
364    }
365
366    fn bitshiftleft(self, rhs: Self) -> Self {
367        let a = f64::from(self) as i64;
368        let b = f64::from(rhs) as i64;
369        bf16::from_f32((a << b) as f32)
370    }
371
372    fn bitshiftright(self, rhs: Self) -> Self {
373        let a = f64::from(self) as i64;
374        let b = f64::from(rhs) as i64;
375        bf16::from_f32((a >> b) as f32)
376    }
377
378    fn and(self, rhs: Self) -> bool {
379        self != Self::ZERO && rhs != Self::ZERO
380    }
381
382    fn max(self, rhs: Self) -> Self {
383        self.max(rhs)
384    }
385
386    fn max_value() -> Self {
387        bf16::MAX
388    }
389
390    fn min_value() -> Self {
391        bf16::MIN
392    }
393
394    fn is_equal(self, rhs: Self) -> bool {
395        self == rhs
396    }
397
398    fn epsilon() -> Self {
399        bf16::MIN_POSITIVE
400    }
401}
402
403impl Float for bf16 {
404    fn reciprocal(self) -> Self {
405        bf16::ONE / self
406    }
407
408    fn floor(self) -> Self {
409        bf16::from_f32(self.to_f32().floor())
410    }
411
412    fn sin(self) -> Self {
413        bf16::from_f32(self.to_f32().sin())
414    }
415
416    fn cos(self) -> Self {
417        bf16::from_f32(self.to_f32().cos())
418    }
419
420    fn sqrt(self) -> Self {
421        bf16::from_f32(self.to_f32().sqrt())
422    }
423
424    fn trunc(self) -> Self {
425        bf16::from_f32(self.to_f32().trunc())
426    }
427}
428
429impl Scalar for f16 {
430    fn from_bf16(t: bf16) -> Self {
431        f16::from_f32(t.to_f32())
432    }
433
434    fn from_f16(t: f16) -> Self {
435        f16::from_f32(t.to_f32())
436    }
437
438    fn from_f32(t: f32) -> Self {
439        f16::from_f32(t)
440    }
441
442    fn from_f64(t: f64) -> Self {
443        f16::from_f64(t)
444    }
445
446    fn from_u8(t: u8) -> Self {
447        f16::from_f32(t as f32)
448    }
449
450    fn from_u16(t: u16) -> Self {
451        f16::from_f32(t as f32)
452    }
453
454    fn from_u32(t: u32) -> Self {
455        f16::from_f64(t.into())
456    }
457
458    fn from_u64(t: u64) -> Self {
459        f16::from_f64(t as f64)
460    }
461
462    fn from_i8(t: i8) -> Self {
463        f16::from_f32(t as f32)
464    }
465
466    fn from_i16(t: i16) -> Self {
467        f16::from_f32(t as f32)
468    }
469
470    #[allow(clippy::cast_lossless)]
471    fn from_i32(t: i32) -> Self {
472        f16::from_f64(t as f64)
473    }
474
475    #[allow(clippy::cast_precision_loss)]
476    fn from_i64(t: i64) -> Self {
477        f16::from_f64(t as f64)
478    }
479
480    #[allow(clippy::cast_lossless)]
481    fn from_bool(t: bool) -> Self {
482        f16::from_f64(t as i8 as f64)
483    }
484
485    fn from_le_bytes(bytes: &[u8]) -> Self {
486        f16::from_le_bytes([bytes[0], bytes[1]])
487    }
488
489    fn to_ne_bytes(&self) -> &[u8] {
490        let i: *const Self = self;
491        unsafe { std::slice::from_raw_parts(i.cast::<u8>(), std::mem::size_of::<Self>()) }
492    }
493
494    fn dtype() -> DType {
495        DType::F16
496    }
497
498    fn zero() -> Self {
499        f16::ZERO
500    }
501
502    fn one() -> Self {
503        f16::ONE
504    }
505
506    fn abs(self) -> Self {
507        self.max(-self)
508    }
509
510    fn neg(self) -> Self {
511        -self
512    }
513
514    fn exp2(self) -> Self {
515        f16::from_f32(self.to_f32().exp2())
516    }
517
518    fn log2(self) -> Self {
519        f16::from_f32(self.to_f32().log2())
520    }
521
522    fn relu(self) -> Self {
523        self.max(f16::ZERO)
524    }
525
526    fn not(self) -> Self {
527        f16::from_f32(if f32::from(self) == 0.0 { 0.0 } else { 1.0 })
528    }
529
530    fn nonzero(self) -> Self {
531        f16::from_f32(if f32::from(self) == 0.0 { 0.0 } else { 1.0 })
532    }
533
534    fn add(self, rhs: Self) -> Self {
535        self + rhs
536    }
537
538    fn sub(self, rhs: Self) -> Self {
539        self - rhs
540    }
541
542    fn mul(self, rhs: Self) -> Self {
543        self * rhs
544    }
545
546    fn div(self, rhs: Self) -> Self {
547        self / rhs
548    }
549
550    fn pow(self, rhs: Self) -> Self {
551        f16::from_f32(self.to_f32().pow(rhs.to_f32()))
552    }
553
554    fn mod_(self, rhs: Self) -> Self {
555        self % rhs
556    }
557
558    fn cmplt(self, rhs: Self) -> bool {
559        self < rhs
560    }
561
562    fn cmpgt(self, rhs: Self) -> bool {
563        self > rhs
564    }
565
566    fn noteq(self, rhs: Self) -> bool {
567        self != rhs
568    }
569
570    fn or(self, rhs: Self) -> bool {
571        self != f16::ZERO || rhs != f16::ZERO
572    }
573
574    fn bitxor(self, rhs: Self) -> Self {
575        let ix = self.to_bits() ^ rhs.to_bits();
576        f16::from_le_bytes([ix as u8, (ix >> 8) as u8])
577    }
578
579    fn bitor(self, rhs: Self) -> Self {
580        let ix = self.to_bits() | rhs.to_bits();
581        f16::from_le_bytes([ix as u8, (ix >> 8) as u8])
582    }
583
584    fn bitand(self, rhs: Self) -> Self {
585        let ix = self.to_bits() & rhs.to_bits();
586        f16::from_le_bytes([ix as u8, (ix >> 8) as u8])
587    }
588
589    fn bitshiftleft(self, rhs: Self) -> Self {
590        let lhs_f32 = self.to_f32();
591        let rhs_f32 = rhs.to_f32();
592        let lhs_bits = lhs_f32.to_bits() as i32;
593        let rhs_bits = rhs_f32.to_bits() as i32;
594        let result = f32::from_bits((lhs_bits << rhs_bits) as u32);
595        f16::from_f32(result)
596    }
597
598    fn bitshiftright(self, rhs: Self) -> Self {
599        let lhs_f32 = self.to_f32();
600        let rhs_f32 = rhs.to_f32();
601        let lhs_bits = lhs_f32.to_bits() as i32;
602        let rhs_bits = rhs_f32.to_bits() as i32;
603        let result = f32::from_bits((lhs_bits >> rhs_bits) as u32);
604        f16::from_f32(result)
605    }
606
607    fn and(self, rhs: Self) -> bool {
608        self != f16::ZERO && rhs != f16::ZERO
609    }
610
611    fn max(self, rhs: Self) -> Self {
612        f16::max(self, rhs)
613    }
614
615    fn max_value() -> Self {
616        f16::MAX
617    }
618
619    fn min_value() -> Self {
620        f16::MIN
621    }
622
623    fn is_equal(self, rhs: Self) -> bool {
624        (self == -Self::INFINITY && rhs == -Self::INFINITY)
625            || (self.is_nan() && rhs.is_nan())
626            || self.sub(rhs).abs() < self.abs() * f16::from_f32(0.0001)
627    }
628
629    fn epsilon() -> Self {
630        f16::from_f32(0.00001)
631    }
632}
633
634impl Float for f16 {
635    fn reciprocal(self) -> Self {
636        f16::ONE / self
637    }
638
639    fn sin(self) -> Self {
640        f16::from_f32(self.to_f32().sin())
641    }
642
643    fn cos(self) -> Self {
644        f16::from_f32(self.to_f32().cos())
645    }
646
647    fn sqrt(self) -> Self {
648        f16::from_f32(self.to_f32().sqrt())
649    }
650
651    fn floor(self) -> Self {
652        f16::from_f32(self.to_f32().floor())
653    }
654
655    fn trunc(self) -> Self {
656        f16::from_f32(self.to_f32().trunc())
657    }
658}
659
660impl Scalar for f32 {
661    fn from_bf16(t: bf16) -> Self {
662        t.into()
663    }
664
665    fn from_f16(t: f16) -> Self {
666        t.into()
667    }
668
669    fn from_f32(t: f32) -> Self {
670        t
671    }
672
673    #[allow(clippy::cast_possible_truncation)]
674    fn from_f64(t: f64) -> Self {
675        t as Self
676    }
677
678    fn from_u8(t: u8) -> Self {
679        f32::from(t)
680    }
681
682    fn from_u16(t: u16) -> Self {
683        t.into()
684    }
685
686    #[allow(clippy::cast_precision_loss)]
687    fn from_u32(t: u32) -> Self {
688        t as f32
689    }
690
691    #[allow(clippy::cast_precision_loss)]
692    fn from_u64(t: u64) -> Self {
693        t as f32
694    }
695
696    fn from_i8(t: i8) -> Self {
697        f32::from(t)
698    }
699
700    fn from_i16(t: i16) -> Self {
701        f32::from(t)
702    }
703
704    #[allow(clippy::cast_precision_loss)]
705    fn from_i32(t: i32) -> Self {
706        t as f32
707    }
708
709    #[allow(clippy::cast_precision_loss)]
710    fn from_i64(t: i64) -> Self {
711        t as f32
712    }
713
714    fn from_bool(t: bool) -> Self {
715        f32::from(i8::from(t))
716    }
717
718    fn from_le_bytes(bytes: &[u8]) -> Self {
719        f32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]])
720    }
721
722    fn to_ne_bytes(&self) -> &[u8] {
723        let i: *const Self = self;
724        unsafe { std::slice::from_raw_parts(i.cast::<u8>(), std::mem::size_of::<Self>()) }
725    }
726
727    fn dtype() -> DType {
728        DType::F32
729    }
730
731    fn zero() -> Self {
732        0.
733    }
734
735    fn one() -> Self {
736        1.
737    }
738
739    fn abs(self) -> Self {
740        self.abs()
741    }
742
743    fn neg(self) -> Self {
744        -self
745    }
746
747    fn exp2(self) -> Self {
748        f32::exp2(self)
749    }
750
751    fn log2(self) -> Self {
752        self.log2()
753    }
754
755    fn relu(self) -> Self {
756        self.max(0.)
757    }
758
759    fn not(self) -> Self {
760        if self == 0. { 1. } else { 0. }
761    }
762
763    fn nonzero(self) -> Self {
764        f32::from(i8::from(self != 0.))
765    }
766
767    fn add(self, rhs: Self) -> Self {
768        self + rhs
769    }
770
771    fn sub(self, rhs: Self) -> Self {
772        self - rhs
773    }
774
775    fn mul(self, rhs: Self) -> Self {
776        self * rhs
777    }
778
779    fn div(self, rhs: Self) -> Self {
780        self / rhs
781    }
782
783    fn pow(self, rhs: Self) -> Self {
784        self.powf(rhs)
785    }
786
787    fn mod_(self, rhs: Self) -> Self {
788        self % rhs
789    }
790
791    fn cmplt(self, rhs: Self) -> bool {
792        self < rhs
793    }
794
795    fn cmpgt(self, rhs: Self) -> bool {
796        self > rhs
797    }
798
799    fn noteq(self, rhs: Self) -> bool {
800        !self.is_equal(rhs)
801    }
802
803    fn or(self, rhs: Self) -> bool {
804        self != 0. || rhs != 0.
805    }
806
807    fn bitxor(self, rhs: Self) -> Self {
808        let rhs_bits = rhs.to_bits();
809        f32::from_bits(self.to_bits() ^ rhs_bits)
810    }
811
812    fn bitor(self, rhs: Self) -> Self {
813        let rhs_bits = rhs.to_bits();
814        f32::from_bits(self.to_bits() | rhs_bits)
815    }
816
817    fn bitand(self, rhs: Self) -> Self {
818        let rhs_bits = rhs.to_bits();
819        f32::from_bits(self.to_bits() & rhs_bits)
820    }
821
822    fn bitshiftleft(self, rhs: Self) -> Self {
823        let rhs_shift = (rhs.to_bits() & 0xFF) as u32;
824        let ix = (self.to_bits() as u64) << rhs_shift;
825        f32::from_bits(ix as u32)
826    }
827
828    fn bitshiftright(self, rhs: Self) -> Self {
829        let rhs_shift = (rhs.to_bits() & 0xFF) as u32;
830        let ix = (self.to_bits() as u32) >> rhs_shift;
831        f32::from_bits(ix)
832    }
833
834    fn and(self, rhs: Self) -> bool {
835        self != 0. && rhs != 0.
836    }
837
838    fn max(self, rhs: Self) -> Self {
839        f32::max(self, rhs)
840    }
841
842    fn max_value() -> Self {
843        f32::MAX
844    }
845
846    fn min_value() -> Self {
847        f32::MIN
848    }
849
850    fn is_equal(self, rhs: Self) -> bool {
851        let a = self;
852        let b = rhs;
853        if a.is_nan() && b.is_nan() {
854            return true;
855        }
856        #[allow(clippy::float_cmp)]
857        if a == b {
858            return true;
859        }
860        let diff = (a - b).abs();
861        let max_abs = a.abs().max(b.abs());
862        let rel_tol = 1e-3 * max_abs; // relative tolerance for large numbers
863        let abs_tol = 2e-7; // absolute tolerance for tiny numbers
864        diff < rel_tol || diff < abs_tol
865    }
866
867    fn epsilon() -> Self {
868        0.0001
869    }
870}
871
872impl Float for f32 {
873    fn sin(self) -> Self {
874        //libm::sinf(self)
875        //let b = 4f32 / PI;
876        //let c = -4f32 / (PI * PI);
877        //return -(b * self + c * self * if self < 0. { -self } else { self });
878        f32::sin(self)
879    }
880
881    fn floor(self) -> Self {
882        f32::floor(self)
883    }
884
885    fn cos(self) -> Self {
886        //libm::cosf(self)
887        //let mut x = self;
888        //x *= 1. / (2. * PI);
889        //x -= 0.25 + (x + 0.25).floor();
890        //x *= 16.0 * (x.abs() - 0.5);
891        //x += 0.225 * x * (x.abs() - 1.0);
892        //return x;
893        f32::cos(self)
894    }
895
896    fn sqrt(self) -> Self {
897        // good enough (error of ~ 5%)
898        /*if self >= 0. {
899            Self::from_bits((self.to_bits() + 0x3f80_0000) >> 1)
900        } else {
901            Self::NAN
902        }*/
903        f32::sqrt(self)
904    }
905
906    fn reciprocal(self) -> Self {
907        1.0 / self
908    }
909
910    fn trunc(self) -> Self {
911        f32::trunc(self)
912    }
913}
914
915impl Scalar for f64 {
916    fn from_bf16(t: bf16) -> Self {
917        t.into()
918    }
919
920    fn from_f16(t: f16) -> Self {
921        t.into()
922    }
923
924    fn from_f32(t: f32) -> Self {
925        f64::from(t)
926    }
927
928    fn from_f64(t: f64) -> Self {
929        t
930    }
931
932    fn from_u8(t: u8) -> Self {
933        f64::from(t)
934    }
935
936    fn from_u16(t: u16) -> Self {
937        t.into()
938    }
939
940    fn from_u32(t: u32) -> Self {
941        t.into()
942    }
943
944    #[allow(clippy::cast_precision_loss)]
945    fn from_u64(t: u64) -> Self {
946        t as f64
947    }
948
949    fn from_i8(t: i8) -> Self {
950        t.into()
951    }
952
953    fn from_i16(t: i16) -> Self {
954        t.into()
955    }
956
957    fn from_i32(t: i32) -> Self {
958        t.into()
959    }
960
961    #[allow(clippy::cast_precision_loss)]
962    fn from_i64(t: i64) -> Self {
963        t as f64
964    }
965
966    fn from_bool(t: bool) -> Self {
967        t.into()
968    }
969
970    fn from_le_bytes(bytes: &[u8]) -> Self {
971        f64::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7]])
972    }
973
974    fn to_ne_bytes(&self) -> &[u8] {
975        let i: *const Self = self;
976        unsafe { std::slice::from_raw_parts(i.cast::<u8>(), std::mem::size_of::<Self>()) }
977    }
978
979    fn dtype() -> DType {
980        DType::F64
981    }
982
983    fn zero() -> Self {
984        0.
985    }
986
987    fn one() -> Self {
988        1.
989    }
990
991    fn abs(self) -> Self {
992        self.abs()
993    }
994
995    fn neg(self) -> Self {
996        -self
997    }
998
999    fn exp2(self) -> Self {
1000        f64::exp2(self)
1001    }
1002
1003    fn log2(self) -> Self {
1004        self.log2()
1005    }
1006
1007    fn relu(self) -> Self {
1008        self.max(0.)
1009    }
1010
1011    fn not(self) -> Self {
1012        if self == 0. { 1. } else { 0. }
1013    }
1014
1015    fn nonzero(self) -> Self {
1016        u8::from(self != 0.).into()
1017    }
1018
1019    fn add(self, rhs: Self) -> Self {
1020        self + rhs
1021    }
1022
1023    fn sub(self, rhs: Self) -> Self {
1024        self - rhs
1025    }
1026
1027    fn mul(self, rhs: Self) -> Self {
1028        self * rhs
1029    }
1030
1031    fn div(self, rhs: Self) -> Self {
1032        self / rhs
1033    }
1034
1035    fn pow(self, rhs: Self) -> Self {
1036        self.powf(rhs)
1037    }
1038
1039    fn mod_(self, rhs: Self) -> Self {
1040        self % rhs
1041    }
1042
1043    fn cmplt(self, rhs: Self) -> bool {
1044        self < rhs
1045    }
1046
1047    fn cmpgt(self, rhs: Self) -> bool {
1048        self > rhs
1049    }
1050
1051    fn noteq(self, rhs: Self) -> bool {
1052        !self.is_equal(rhs)
1053    }
1054
1055    fn or(self, rhs: Self) -> bool {
1056        self != 0. || rhs != 0.
1057    }
1058
1059    fn bitxor(self, rhs: Self) -> Self {
1060        f64::from_bits(self.to_bits() ^ rhs.to_bits())
1061    }
1062
1063    fn bitor(self, rhs: Self) -> Self {
1064        f64::from_bits(self.to_bits() | rhs.to_bits())
1065    }
1066
1067    fn bitand(self, rhs: Self) -> Self {
1068        f64::from_bits(self.to_bits() & rhs.to_bits())
1069    }
1070
1071    fn bitshiftleft(self, rhs: Self) -> Self {
1072        let rhs_shift = (rhs.to_bits() & 0xFF) as u32;
1073        let ix = (self.to_bits() as u64) << rhs_shift;
1074        f64::from_bits(ix)
1075    }
1076
1077    fn bitshiftright(self, rhs: Self) -> Self {
1078        let rhs_shift = (rhs.to_bits() & 0xFF) as i32;
1079        let ix = ((self.to_bits() as u64) >> rhs_shift) as u32;
1080        f64::from_bits(ix as u64)
1081    }
1082
1083    fn and(self, rhs: Self) -> bool {
1084        self != 0. && rhs != 0.
1085    }
1086
1087    fn max(self, rhs: Self) -> Self {
1088        f64::max(self, rhs)
1089    }
1090
1091    fn max_value() -> Self {
1092        f64::MAX
1093    }
1094
1095    fn min_value() -> Self {
1096        f64::MIN
1097    }
1098
1099    fn is_equal(self, rhs: Self) -> bool {
1100        // Less than 0.1% error is OK
1101        (self == -f64::INFINITY && rhs == -f64::INFINITY) || (self - rhs).abs() <= self.abs() * 0.001
1102    }
1103
1104    fn epsilon() -> Self {
1105        0.00001
1106    }
1107}
1108
1109impl Float for f64 {
1110    fn reciprocal(self) -> Self {
1111        1.0 / self
1112    }
1113
1114    fn floor(self) -> Self {
1115        self.floor()
1116    }
1117
1118    fn sin(self) -> Self {
1119        f64::sin(self)
1120    }
1121
1122    fn cos(self) -> Self {
1123        f64::cos(self)
1124    }
1125
1126    fn sqrt(self) -> Self {
1127        f64::sqrt(self)
1128    }
1129
1130    fn trunc(self) -> Self {
1131        self.trunc()
1132    }
1133}
1134
1135impl Scalar for i8 {
1136    #[allow(clippy::cast_possible_truncation)]
1137    fn from_bf16(t: bf16) -> Self {
1138        let t: f32 = t.into();
1139        t as Self
1140    }
1141
1142    #[allow(clippy::cast_possible_truncation)]
1143    fn from_f16(t: f16) -> Self {
1144        let t: f32 = t.into();
1145        t as Self
1146    }
1147
1148    #[allow(clippy::cast_possible_truncation)]
1149    fn from_f32(t: f32) -> Self {
1150        t as Self
1151    }
1152
1153    #[allow(clippy::cast_possible_truncation)]
1154    fn from_f64(t: f64) -> Self {
1155        t as Self
1156    }
1157
1158    fn from_u8(t: u8) -> Self {
1159        t.try_into().unwrap()
1160    }
1161
1162    fn from_u16(t: u16) -> Self {
1163        t.try_into().unwrap()
1164    }
1165
1166    fn from_u32(t: u32) -> Self {
1167        t.try_into().unwrap()
1168    }
1169
1170    fn from_u64(t: u64) -> Self {
1171        t.try_into().unwrap()
1172    }
1173
1174    fn from_i8(t: i8) -> Self {
1175        t
1176    }
1177
1178    fn from_i16(t: i16) -> Self {
1179        Self::try_from(t).unwrap()
1180    }
1181
1182    fn from_i32(t: i32) -> Self {
1183        Self::try_from(t).unwrap()
1184    }
1185
1186    fn from_i64(t: i64) -> Self {
1187        Self::try_from(t).unwrap()
1188    }
1189
1190    fn from_bool(t: bool) -> Self {
1191        Self::from(t)
1192    }
1193
1194    fn from_le_bytes(bytes: &[u8]) -> Self {
1195        i8::from_le_bytes([bytes[0]])
1196    }
1197
1198    fn to_ne_bytes(&self) -> &[u8] {
1199        let i: *const Self = self;
1200        unsafe { std::slice::from_raw_parts(i.cast::<u8>(), std::mem::size_of::<Self>()) }
1201    }
1202
1203    fn dtype() -> DType {
1204        DType::I8
1205    }
1206
1207    fn zero() -> Self {
1208        0
1209    }
1210
1211    fn one() -> Self {
1212        1
1213    }
1214
1215    fn abs(self) -> Self {
1216        self.abs()
1217    }
1218
1219    fn neg(self) -> Self {
1220        -self
1221    }
1222
1223    fn exp2(self) -> Self {
1224        2i32.pow(self as u32) as i8
1225    }
1226
1227    fn log2(self) -> Self {
1228        f64::from(self).log2() as i8
1229    }
1230
1231    fn relu(self) -> Self {
1232        Scalar::max(self, 0)
1233    }
1234
1235    fn not(self) -> Self {
1236        i8::from(self == 0)
1237    }
1238
1239    fn nonzero(self) -> Self {
1240        i8::from(self != 0)
1241    }
1242
1243    fn add(self, rhs: Self) -> Self {
1244        self + rhs
1245    }
1246
1247    fn sub(self, rhs: Self) -> Self {
1248        self - rhs
1249    }
1250
1251    fn mul(self, rhs: Self) -> Self {
1252        self * rhs
1253    }
1254
1255    fn div(self, rhs: Self) -> Self {
1256        self / rhs
1257    }
1258
1259    fn pow(self, rhs: Self) -> Self {
1260        if rhs >= 0 {
1261            return self.pow(rhs as u32);
1262        }
1263        if self == 1 {
1264            return 1;
1265        }
1266        if self == -1 {
1267            return if rhs % 2 == 0 { 1 } else { -1 };
1268        }
1269        0
1270    }
1271
1272    fn mod_(self, rhs: Self) -> Self {
1273        self % rhs
1274    }
1275
1276    fn cmplt(self, rhs: Self) -> bool {
1277        self < rhs
1278    }
1279
1280    fn cmpgt(self, rhs: Self) -> bool {
1281        self > rhs
1282    }
1283
1284    fn noteq(self, rhs: Self) -> bool {
1285        self != rhs
1286    }
1287
1288    fn or(self, rhs: Self) -> bool {
1289        self != 0 || rhs != 0
1290    }
1291
1292    fn bitxor(self, rhs: Self) -> Self {
1293        self ^ rhs
1294    }
1295
1296    fn bitor(self, rhs: Self) -> Self {
1297        self | rhs
1298    }
1299
1300    fn bitand(self, rhs: Self) -> Self {
1301        self & rhs
1302    }
1303
1304    fn bitshiftleft(self, rhs: Self) -> Self {
1305        self.wrapping_shl(rhs as u32)
1306    }
1307
1308    fn bitshiftright(self, rhs: Self) -> Self {
1309        self.wrapping_shr(rhs as u32)
1310    }
1311
1312    fn and(self, rhs: Self) -> bool {
1313        self != 0 && rhs != 0
1314    }
1315
1316    fn max(self, rhs: Self) -> Self {
1317        <i8 as Ord>::max(self, rhs)
1318    }
1319
1320    fn max_value() -> Self {
1321        i8::MAX
1322    }
1323
1324    fn min_value() -> Self {
1325        i8::MIN
1326    }
1327
1328    fn is_equal(self, rhs: Self) -> bool {
1329        self == rhs
1330    }
1331
1332    fn epsilon() -> Self {
1333        0
1334    }
1335}
1336
1337impl Scalar for i16 {
1338    #[allow(clippy::cast_possible_truncation)]
1339    fn from_bf16(t: bf16) -> Self {
1340        t.to_f32() as i16
1341    }
1342
1343    #[allow(clippy::cast_possible_truncation)]
1344    fn from_f16(t: f16) -> Self {
1345        t.to_f32() as i16
1346    }
1347
1348    #[allow(clippy::cast_possible_truncation)]
1349    fn from_f32(t: f32) -> Self {
1350        t as i16
1351    }
1352
1353    #[allow(clippy::cast_possible_truncation)]
1354    fn from_f64(t: f64) -> Self {
1355        t as i16
1356    }
1357
1358    fn from_u8(t: u8) -> Self {
1359        t.into()
1360    }
1361
1362    #[allow(clippy::cast_possible_truncation)]
1363    #[allow(clippy::cast_possible_wrap)]
1364    fn from_u16(t: u16) -> Self {
1365        t as i16
1366    }
1367
1368    #[allow(clippy::cast_possible_truncation)]
1369    fn from_u32(t: u32) -> Self {
1370        t as i16
1371    }
1372
1373    fn from_u64(t: u64) -> Self {
1374        t.try_into().unwrap()
1375    }
1376
1377    fn from_i8(t: i8) -> Self {
1378        t.into()
1379    }
1380
1381    fn from_i16(t: i16) -> Self {
1382        t
1383    }
1384
1385    #[allow(clippy::cast_possible_truncation)]
1386    fn from_i32(t: i32) -> Self {
1387        t as i16
1388    }
1389
1390    #[allow(clippy::cast_possible_truncation)]
1391    fn from_i64(t: i64) -> Self {
1392        t as i16
1393    }
1394
1395    fn from_bool(t: bool) -> Self {
1396        t.into()
1397    }
1398
1399    fn from_le_bytes(bytes: &[u8]) -> Self {
1400        i16::from_le_bytes([bytes[0], bytes[1]])
1401    }
1402
1403    fn to_ne_bytes(&self) -> &[u8] {
1404        let i: *const Self = self;
1405        unsafe { std::slice::from_raw_parts(i.cast::<u8>(), std::mem::size_of::<Self>()) }
1406    }
1407
1408    fn dtype() -> DType {
1409        DType::I16
1410    }
1411
1412    fn zero() -> Self {
1413        0
1414    }
1415
1416    fn one() -> Self {
1417        1
1418    }
1419
1420    fn abs(self) -> Self {
1421        self.abs()
1422    }
1423
1424    fn neg(self) -> Self {
1425        -self
1426    }
1427
1428    fn exp2(self) -> Self {
1429        2i32.pow(self as u32) as i16
1430    }
1431
1432    fn log2(self) -> Self {
1433        f64::from(self).log2() as i16
1434    }
1435
1436    fn relu(self) -> Self {
1437        Scalar::max(self, 0)
1438    }
1439
1440    fn not(self) -> Self {
1441        i16::from(self == 0)
1442    }
1443
1444    fn nonzero(self) -> Self {
1445        i16::from(self != 0)
1446    }
1447
1448    fn add(self, rhs: Self) -> Self {
1449        self + rhs
1450    }
1451
1452    fn sub(self, rhs: Self) -> Self {
1453        self - rhs
1454    }
1455
1456    fn mul(self, rhs: Self) -> Self {
1457        self * rhs
1458    }
1459
1460    fn div(self, rhs: Self) -> Self {
1461        self / rhs
1462    }
1463
1464    fn pow(self, rhs: Self) -> Self {
1465        if rhs >= 0 {
1466            return self.pow(rhs as u32);
1467        }
1468        if self == 1 {
1469            return 1;
1470        }
1471        if self == -1 {
1472            return if rhs % 2 == 0 { 1 } else { -1 };
1473        }
1474        0
1475    }
1476
1477    fn mod_(self, rhs: Self) -> Self {
1478        self % rhs
1479    }
1480
1481    fn cmplt(self, rhs: Self) -> bool {
1482        self < rhs
1483    }
1484
1485    fn cmpgt(self, rhs: Self) -> bool {
1486        self > rhs
1487    }
1488
1489    fn noteq(self, rhs: Self) -> bool {
1490        self != rhs
1491    }
1492
1493    fn or(self, rhs: Self) -> bool {
1494        self != 0 || rhs != 0
1495    }
1496
1497    fn bitxor(self, rhs: Self) -> Self {
1498        self ^ rhs
1499    }
1500
1501    fn bitor(self, rhs: Self) -> Self {
1502        self | rhs
1503    }
1504
1505    fn bitand(self, rhs: Self) -> Self {
1506        self & rhs
1507    }
1508
1509    fn bitshiftleft(self, rhs: Self) -> Self {
1510        self.wrapping_shl(rhs as u32)
1511    }
1512
1513    fn bitshiftright(self, rhs: Self) -> Self {
1514        self.wrapping_shr(rhs as u32)
1515    }
1516
1517    fn and(self, rhs: Self) -> bool {
1518        self != 0 && rhs != 0
1519    }
1520
1521    fn max(self, rhs: Self) -> Self {
1522        Ord::max(self, rhs)
1523    }
1524
1525    fn max_value() -> Self {
1526        i16::MAX
1527    }
1528
1529    fn min_value() -> Self {
1530        i16::MIN
1531    }
1532
1533    fn is_equal(self, rhs: Self) -> bool {
1534        self == rhs
1535    }
1536
1537    fn epsilon() -> Self {
1538        0
1539    }
1540}
1541
1542impl Scalar for i32 {
1543    fn from_bf16(t: bf16) -> Self {
1544        t.to_f32() as i32
1545    }
1546
1547    fn from_f16(t: f16) -> Self {
1548        t.to_f32() as i32
1549    }
1550
1551    #[allow(clippy::cast_possible_truncation)]
1552    fn from_f32(t: f32) -> Self {
1553        t as i32
1554    }
1555
1556    #[allow(clippy::cast_possible_truncation)]
1557    fn from_f64(t: f64) -> Self {
1558        t as i32
1559    }
1560
1561    fn from_u8(t: u8) -> Self {
1562        t.into()
1563    }
1564
1565    fn from_u16(t: u16) -> Self {
1566        t.into()
1567    }
1568
1569    fn from_u32(t: u32) -> Self {
1570        i32::try_from(t).unwrap()
1571    }
1572
1573    fn from_u64(t: u64) -> Self {
1574        t.try_into().unwrap()
1575    }
1576
1577    fn from_i8(t: i8) -> Self {
1578        t.into()
1579    }
1580
1581    fn from_i16(t: i16) -> Self {
1582        t.into()
1583    }
1584
1585    fn from_i32(t: i32) -> Self {
1586        t
1587    }
1588
1589    #[allow(clippy::cast_possible_truncation)]
1590    fn from_i64(t: i64) -> Self {
1591        t as i32
1592    }
1593
1594    fn from_bool(t: bool) -> Self {
1595        t.into()
1596    }
1597
1598    fn from_le_bytes(bytes: &[u8]) -> Self {
1599        i32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]])
1600    }
1601
1602    fn to_ne_bytes(&self) -> &[u8] {
1603        let i: *const i32 = self;
1604        unsafe { std::slice::from_raw_parts(i.cast::<u8>(), std::mem::size_of::<i32>()) }
1605    }
1606
1607    fn dtype() -> DType {
1608        DType::I32
1609    }
1610
1611    fn zero() -> Self {
1612        0
1613    }
1614
1615    fn one() -> Self {
1616        1
1617    }
1618
1619    fn abs(self) -> Self {
1620        self.abs()
1621    }
1622
1623    fn neg(self) -> Self {
1624        -self
1625    }
1626
1627    fn exp2(self) -> Self {
1628        2i32.pow(self as u32)
1629    }
1630
1631    fn log2(self) -> Self {
1632        f64::from(self).log2() as i32
1633    }
1634
1635    fn relu(self) -> Self {
1636        Scalar::max(self, 0)
1637    }
1638
1639    fn not(self) -> Self {
1640        i32::from(self == 0)
1641    }
1642
1643    fn nonzero(self) -> Self {
1644        i32::from(self != 0)
1645    }
1646
1647    fn add(self, rhs: Self) -> Self {
1648        self + rhs
1649    }
1650
1651    fn sub(self, rhs: Self) -> Self {
1652        self - rhs
1653    }
1654
1655    fn mul(self, rhs: Self) -> Self {
1656        self * rhs
1657    }
1658
1659    fn div(self, rhs: Self) -> Self {
1660        self / rhs
1661    }
1662
1663    fn pow(self, rhs: Self) -> Self {
1664        i32::pow(self, u32::try_from(rhs).unwrap())
1665    }
1666
1667    fn mod_(self, rhs: Self) -> Self {
1668        self % rhs
1669    }
1670
1671    fn cmplt(self, rhs: Self) -> bool {
1672        self < rhs
1673    }
1674
1675    fn cmpgt(self, rhs: Self) -> bool {
1676        self > rhs
1677    }
1678
1679    fn noteq(self, rhs: Self) -> bool {
1680        self != rhs
1681    }
1682
1683    fn or(self, rhs: Self) -> bool {
1684        self != 0 || rhs != 0
1685    }
1686
1687    fn bitxor(self, rhs: Self) -> Self {
1688        self ^ rhs
1689    }
1690
1691    fn bitor(self, rhs: Self) -> Self {
1692        self | rhs
1693    }
1694
1695    fn bitand(self, rhs: Self) -> Self {
1696        self & rhs
1697    }
1698
1699    fn bitshiftleft(self, rhs: Self) -> Self {
1700        self.wrapping_shl(rhs as u32)
1701    }
1702
1703    fn bitshiftright(self, rhs: Self) -> Self {
1704        self.wrapping_shr(rhs as u32)
1705    }
1706
1707    fn and(self, rhs: Self) -> bool {
1708        self != 0 && rhs != 0
1709    }
1710
1711    fn max(self, rhs: Self) -> Self {
1712        <i32 as Ord>::max(self, rhs)
1713    }
1714
1715    fn max_value() -> Self {
1716        i32::MAX
1717    }
1718
1719    fn min_value() -> Self {
1720        i32::MIN
1721    }
1722
1723    fn is_equal(self, rhs: Self) -> bool {
1724        self == rhs
1725    }
1726
1727    fn epsilon() -> Self {
1728        0
1729    }
1730}
1731
1732impl Scalar for i64 {
1733    fn from_bf16(t: bf16) -> Self {
1734        t.to_f32() as i64
1735    }
1736
1737    fn from_f16(t: f16) -> Self {
1738        t.to_f32() as i64
1739    }
1740
1741    #[allow(clippy::cast_possible_truncation)]
1742    fn from_f32(t: f32) -> Self {
1743        t as Self
1744    }
1745
1746    #[allow(clippy::cast_possible_truncation)]
1747    fn from_f64(t: f64) -> Self {
1748        t as Self
1749    }
1750
1751    fn from_u8(t: u8) -> Self {
1752        t.into()
1753    }
1754
1755    fn from_u16(t: u16) -> Self {
1756        t.into()
1757    }
1758
1759    fn from_u32(t: u32) -> Self {
1760        t.into()
1761    }
1762
1763    fn from_u64(t: u64) -> Self {
1764        t.try_into().unwrap()
1765    }
1766
1767    fn from_i8(t: i8) -> Self {
1768        t.into()
1769    }
1770
1771    fn from_i16(t: i16) -> Self {
1772        t.into()
1773    }
1774
1775    fn from_i32(t: i32) -> Self {
1776        t.into()
1777    }
1778
1779    fn from_i64(t: i64) -> Self {
1780        t
1781    }
1782
1783    fn from_bool(t: bool) -> Self {
1784        t.into()
1785    }
1786
1787    fn from_le_bytes(bytes: &[u8]) -> Self {
1788        i64::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7]])
1789    }
1790
1791    fn to_ne_bytes(&self) -> &[u8] {
1792        let i: *const Self = self;
1793        unsafe { std::slice::from_raw_parts(i.cast::<u8>(), std::mem::size_of::<Self>()) }
1794    }
1795
1796    fn dtype() -> DType {
1797        DType::I64
1798    }
1799
1800    fn zero() -> Self {
1801        0
1802    }
1803
1804    fn one() -> Self {
1805        1
1806    }
1807
1808    fn abs(self) -> Self {
1809        self.abs()
1810    }
1811
1812    fn neg(self) -> Self {
1813        -self
1814    }
1815
1816    fn exp2(self) -> Self {
1817        2i64.pow(self as u32)
1818    }
1819
1820    fn log2(self) -> Self {
1821        self as f64 as i64
1822    }
1823
1824    fn relu(self) -> Self {
1825        Scalar::max(self, 0)
1826    }
1827
1828    fn not(self) -> Self {
1829        i64::from(self == 0)
1830    }
1831
1832    fn nonzero(self) -> Self {
1833        i64::from(self != 0)
1834    }
1835
1836    fn add(self, rhs: Self) -> Self {
1837        self + rhs
1838    }
1839
1840    fn sub(self, rhs: Self) -> Self {
1841        self - rhs
1842    }
1843
1844    fn mul(self, rhs: Self) -> Self {
1845        self * rhs
1846    }
1847
1848    fn div(self, rhs: Self) -> Self {
1849        self / rhs
1850    }
1851
1852    fn pow(self, rhs: Self) -> Self {
1853        i64::pow(self, u32::try_from(rhs).unwrap())
1854    }
1855
1856    fn mod_(self, rhs: Self) -> Self {
1857        self % rhs
1858    }
1859
1860    fn cmplt(self, rhs: Self) -> bool {
1861        self < rhs
1862    }
1863
1864    fn cmpgt(self, rhs: Self) -> bool {
1865        self > rhs
1866    }
1867
1868    fn noteq(self, rhs: Self) -> bool {
1869        self != rhs
1870    }
1871
1872    fn or(self, rhs: Self) -> bool {
1873        self != 0 || rhs != 0
1874    }
1875
1876    fn bitxor(self, rhs: Self) -> Self {
1877        self ^ rhs
1878    }
1879
1880    fn bitor(self, rhs: Self) -> Self {
1881        self | rhs
1882    }
1883
1884    fn bitand(self, rhs: Self) -> Self {
1885        self & rhs
1886    }
1887
1888    fn bitshiftleft(self, rhs: Self) -> Self {
1889        self.wrapping_shl(rhs as u32)
1890    }
1891
1892    fn bitshiftright(self, rhs: Self) -> Self {
1893        self.wrapping_shr(rhs as u32)
1894    }
1895
1896    fn and(self, rhs: Self) -> bool {
1897        self != 0 && rhs != 0
1898    }
1899
1900    fn max(self, rhs: Self) -> Self {
1901        <i64 as Ord>::max(self, rhs)
1902    }
1903
1904    fn max_value() -> Self {
1905        Self::MAX
1906    }
1907
1908    fn min_value() -> Self {
1909        Self::MIN
1910    }
1911
1912    fn is_equal(self, rhs: Self) -> bool {
1913        self == rhs
1914    }
1915
1916    fn epsilon() -> Self {
1917        0
1918    }
1919}
1920
1921impl Scalar for u8 {
1922    fn from_bf16(t: bf16) -> Self {
1923        t.to_f32() as u32 as u8
1924    }
1925
1926    fn from_f16(t: f16) -> Self {
1927        t.to_f32() as u32 as u8
1928    }
1929
1930    fn from_f32(t: f32) -> Self {
1931        t as u32 as u8
1932    }
1933
1934    fn from_f64(t: f64) -> Self {
1935        t as u32 as u8
1936    }
1937
1938    fn from_u8(t: u8) -> Self {
1939        t
1940    }
1941
1942    fn from_u16(t: u16) -> Self {
1943        t.try_into().unwrap()
1944    }
1945
1946    fn from_u32(t: u32) -> Self {
1947        t.try_into().unwrap()
1948    }
1949
1950    fn from_u64(t: u64) -> Self {
1951        t.try_into().unwrap()
1952    }
1953
1954    fn from_i8(t: i8) -> Self {
1955        t.try_into().unwrap()
1956    }
1957
1958    fn from_i16(t: i16) -> Self {
1959        t.try_into().unwrap()
1960    }
1961
1962    fn from_i32(t: i32) -> Self {
1963        t.try_into().unwrap()
1964    }
1965
1966    fn from_i64(t: i64) -> Self {
1967        t.try_into().unwrap()
1968    }
1969
1970    fn from_bool(t: bool) -> Self {
1971        t.into()
1972    }
1973
1974    fn from_le_bytes(bytes: &[u8]) -> Self {
1975        u8::from_le_bytes([bytes[0]])
1976    }
1977
1978    fn to_ne_bytes(&self) -> &[u8] {
1979        let i: *const Self = self;
1980        unsafe { std::slice::from_raw_parts(i.cast::<u8>(), std::mem::size_of::<Self>()) }
1981    }
1982
1983    fn dtype() -> DType {
1984        DType::U8
1985    }
1986
1987    fn zero() -> Self {
1988        0
1989    }
1990
1991    fn one() -> Self {
1992        1
1993    }
1994
1995    fn abs(self) -> Self {
1996        self
1997    }
1998
1999    fn neg(self) -> Self {
2000        self.wrapping_neg()
2001    }
2002
2003    fn exp2(self) -> Self {
2004        if self <= 31 { 2u32.pow(self as u32) as u8 } else { 255 }
2005    }
2006
2007    fn log2(self) -> Self {
2008        self.ilog2() as u8
2009    }
2010
2011    fn relu(self) -> Self {
2012        self
2013    }
2014
2015    fn not(self) -> Self {
2016        u8::from(self == 0)
2017    }
2018
2019    fn nonzero(self) -> Self {
2020        u8::from(self != 0)
2021    }
2022
2023    fn add(self, rhs: Self) -> Self {
2024        self + rhs
2025    }
2026
2027    fn sub(self, rhs: Self) -> Self {
2028        self - rhs
2029    }
2030
2031    fn mul(self, rhs: Self) -> Self {
2032        self * rhs
2033    }
2034
2035    fn div(self, rhs: Self) -> Self {
2036        self / rhs
2037    }
2038
2039    fn pow(self, rhs: Self) -> Self {
2040        Self::pow(self, u32::from(rhs))
2041    }
2042
2043    fn mod_(self, rhs: Self) -> Self {
2044        self % rhs
2045    }
2046
2047    fn cmplt(self, rhs: Self) -> bool {
2048        self < rhs
2049    }
2050
2051    fn cmpgt(self, rhs: Self) -> bool {
2052        self > rhs
2053    }
2054
2055    fn noteq(self, rhs: Self) -> bool {
2056        self != rhs
2057    }
2058
2059    fn or(self, rhs: Self) -> bool {
2060        self != 0 || rhs != 0
2061    }
2062
2063    fn bitxor(self, rhs: Self) -> Self {
2064        self ^ rhs
2065    }
2066
2067    fn bitor(self, rhs: Self) -> Self {
2068        self | rhs
2069    }
2070
2071    fn bitand(self, rhs: Self) -> Self {
2072        self & rhs
2073    }
2074
2075    fn bitshiftleft(self, rhs: Self) -> Self {
2076        self.wrapping_shl(rhs as u32)
2077    }
2078
2079    fn bitshiftright(self, rhs: Self) -> Self {
2080        self.wrapping_shr(rhs as u32)
2081    }
2082
2083    fn and(self, rhs: Self) -> bool {
2084        self != 0 && rhs != 0
2085    }
2086
2087    fn max(self, rhs: Self) -> Self {
2088        Ord::max(self, rhs)
2089    }
2090
2091    fn max_value() -> Self {
2092        u8::MAX
2093    }
2094
2095    fn min_value() -> Self {
2096        u8::MIN
2097    }
2098
2099    fn is_equal(self, rhs: Self) -> bool {
2100        self == rhs
2101    }
2102
2103    fn epsilon() -> Self {
2104        0
2105    }
2106}
2107
2108impl Scalar for u16 {
2109    fn from_bf16(t: bf16) -> Self {
2110        t.to_f32() as u32 as u16
2111    }
2112
2113    fn from_f16(t: f16) -> Self {
2114        t.to_f32() as u32 as u16
2115    }
2116
2117    fn from_f32(t: f32) -> Self {
2118        t as u32 as u16
2119    }
2120
2121    fn from_f64(t: f64) -> Self {
2122        t as u32 as u16
2123    }
2124
2125    fn from_u8(t: u8) -> Self {
2126        t.into()
2127    }
2128
2129    fn from_u16(t: u16) -> Self {
2130        t
2131    }
2132
2133    fn from_u32(t: u32) -> Self {
2134        t.try_into().unwrap()
2135    }
2136
2137    fn from_u64(t: u64) -> Self {
2138        t.try_into().unwrap()
2139    }
2140
2141    fn from_i8(t: i8) -> Self {
2142        t.try_into().unwrap()
2143    }
2144
2145    fn from_i16(t: i16) -> Self {
2146        t.try_into().unwrap()
2147    }
2148
2149    fn from_i32(t: i32) -> Self {
2150        t.try_into().unwrap()
2151    }
2152
2153    fn from_i64(t: i64) -> Self {
2154        t.try_into().unwrap()
2155    }
2156
2157    fn from_bool(t: bool) -> Self {
2158        t.into()
2159    }
2160
2161    fn from_le_bytes(bytes: &[u8]) -> Self {
2162        Self::from_le_bytes([bytes[0], bytes[1]])
2163    }
2164
2165    fn to_ne_bytes(&self) -> &[u8] {
2166        let i: *const Self = self;
2167        unsafe { std::slice::from_raw_parts(i.cast::<u8>(), std::mem::size_of::<Self>()) }
2168    }
2169
2170    fn dtype() -> DType {
2171        DType::U16
2172    }
2173
2174    fn zero() -> Self {
2175        0
2176    }
2177
2178    fn one() -> Self {
2179        1
2180    }
2181
2182    fn abs(self) -> Self {
2183        self
2184    }
2185
2186    fn neg(self) -> Self {
2187        self.wrapping_neg()
2188    }
2189
2190    fn exp2(self) -> Self {
2191        if self <= 31 { 2u32.pow(self as u32) as u16 } else { 65535 }
2192    }
2193
2194    fn log2(self) -> Self {
2195        self.ilog2() as u16
2196    }
2197
2198    fn relu(self) -> Self {
2199        self
2200    }
2201
2202    fn not(self) -> Self {
2203        u16::from(self == 0)
2204    }
2205
2206    fn nonzero(self) -> Self {
2207        u16::from(self != 0)
2208    }
2209
2210    fn add(self, rhs: Self) -> Self {
2211        self + rhs
2212    }
2213
2214    fn sub(self, rhs: Self) -> Self {
2215        self - rhs
2216    }
2217
2218    fn mul(self, rhs: Self) -> Self {
2219        self * rhs
2220    }
2221
2222    fn div(self, rhs: Self) -> Self {
2223        self / rhs
2224    }
2225
2226    fn pow(self, rhs: Self) -> Self {
2227        Self::pow(self, u32::from(rhs)) as u16
2228    }
2229
2230    fn mod_(self, rhs: Self) -> Self {
2231        self % rhs
2232    }
2233
2234    fn cmplt(self, rhs: Self) -> bool {
2235        self < rhs
2236    }
2237
2238    fn cmpgt(self, rhs: Self) -> bool {
2239        self > rhs
2240    }
2241
2242    fn noteq(self, rhs: Self) -> bool {
2243        self != rhs
2244    }
2245
2246    fn or(self, rhs: Self) -> bool {
2247        self != 0 || rhs != 0
2248    }
2249
2250    fn bitxor(self, rhs: Self) -> Self {
2251        self ^ rhs
2252    }
2253
2254    fn bitor(self, rhs: Self) -> Self {
2255        self | rhs
2256    }
2257
2258    fn bitand(self, rhs: Self) -> Self {
2259        self & rhs
2260    }
2261
2262    fn bitshiftleft(self, rhs: Self) -> Self {
2263        self.wrapping_shl(rhs as u32)
2264    }
2265
2266    fn bitshiftright(self, rhs: Self) -> Self {
2267        self.wrapping_shr(rhs as u32)
2268    }
2269
2270    fn and(self, rhs: Self) -> bool {
2271        self != 0 && rhs != 0
2272    }
2273
2274    fn max(self, rhs: Self) -> Self {
2275        Ord::max(self, rhs)
2276    }
2277
2278    fn max_value() -> Self {
2279        Self::MAX
2280    }
2281
2282    fn min_value() -> Self {
2283        Self::MIN
2284    }
2285
2286    fn is_equal(self, rhs: Self) -> bool {
2287        self == rhs
2288    }
2289
2290    fn epsilon() -> Self {
2291        0
2292    }
2293
2294    fn cast<T: Scalar>(self) -> T {
2295        use core::mem::transmute_copy as t;
2296        unsafe {
2297            match Self::dtype() {
2298                DType::BF16 => T::from_bf16(t(&self)),
2299                DType::F16 => T::from_f16(t(&self)),
2300                DType::F32 => T::from_f32(t(&self)),
2301                DType::F64 => T::from_f64(t(&self)),
2302                DType::U8 => T::from_u8(t(&self)),
2303                DType::U16 => T::from_u16(t(&self)),
2304                DType::U32 => T::from_u32(t(&self)),
2305                DType::U64 => T::from_u64(t(&self)),
2306                DType::I8 => T::from_i8(t(&self)),
2307                DType::I16 => T::from_i16(t(&self)),
2308                DType::I32 => T::from_i32(t(&self)),
2309                DType::I64 => T::from_i64(t(&self)),
2310                DType::Bool => T::from_bool(t(&self)),
2311            }
2312        }
2313    }
2314}
2315
2316impl Scalar for u32 {
2317    fn from_bf16(t: bf16) -> Self {
2318        t.to_f32() as u32
2319    }
2320
2321    fn from_f16(t: f16) -> Self {
2322        t.to_f32() as Self
2323    }
2324
2325    fn from_f32(t: f32) -> Self {
2326        t as Self
2327    }
2328
2329    fn from_f64(t: f64) -> Self {
2330        t as Self
2331    }
2332
2333    fn from_u8(t: u8) -> Self {
2334        t.into()
2335    }
2336
2337    fn from_u16(t: u16) -> Self {
2338        t.into()
2339    }
2340
2341    fn from_u32(t: u32) -> Self {
2342        t
2343    }
2344
2345    fn from_u64(t: u64) -> Self {
2346        t.try_into().unwrap()
2347    }
2348
2349    fn from_i8(t: i8) -> Self {
2350        t.try_into().unwrap()
2351    }
2352
2353    fn from_i16(t: i16) -> Self {
2354        t.try_into().unwrap()
2355    }
2356
2357    fn from_i32(t: i32) -> Self {
2358        t.try_into().unwrap()
2359    }
2360
2361    fn from_i64(t: i64) -> Self {
2362        t.try_into().unwrap()
2363    }
2364
2365    fn from_bool(t: bool) -> Self {
2366        t.into()
2367    }
2368
2369    fn from_le_bytes(bytes: &[u8]) -> Self {
2370        Self::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]])
2371    }
2372
2373    fn to_ne_bytes(&self) -> &[u8] {
2374        let i: *const Self = self;
2375        unsafe { std::slice::from_raw_parts(i.cast::<u8>(), std::mem::size_of::<Self>()) }
2376    }
2377
2378    fn dtype() -> DType {
2379        DType::U32
2380    }
2381
2382    fn zero() -> Self {
2383        0
2384    }
2385
2386    fn one() -> Self {
2387        1
2388    }
2389
2390    fn abs(self) -> Self {
2391        self
2392    }
2393
2394    fn neg(self) -> Self {
2395        self.wrapping_neg()
2396    }
2397
2398    fn exp2(self) -> Self {
2399        if self <= 31 { 2u32.pow(self) } else { u32::MAX }
2400    }
2401
2402    fn log2(self) -> Self {
2403        self.ilog2()
2404    }
2405
2406    fn relu(self) -> Self {
2407        self
2408    }
2409
2410    fn not(self) -> Self {
2411        u32::from(self == 0)
2412    }
2413
2414    fn nonzero(self) -> Self {
2415        u32::from(self != 0)
2416    }
2417
2418    fn add(self, rhs: Self) -> Self {
2419        self.wrapping_add(rhs)
2420    }
2421
2422    fn sub(self, rhs: Self) -> Self {
2423        self.wrapping_sub(rhs)
2424    }
2425
2426    fn mul(self, rhs: Self) -> Self {
2427        self * rhs
2428    }
2429
2430    fn div(self, rhs: Self) -> Self {
2431        self / rhs
2432    }
2433
2434    fn pow(self, rhs: Self) -> Self {
2435        u32::pow(self, rhs)
2436    }
2437
2438    fn mod_(self, rhs: Self) -> Self {
2439        self % rhs
2440    }
2441
2442    fn cmplt(self, rhs: Self) -> bool {
2443        self < rhs
2444    }
2445
2446    fn cmpgt(self, rhs: Self) -> bool {
2447        self > rhs
2448    }
2449
2450    fn noteq(self, rhs: Self) -> bool {
2451        self != rhs
2452    }
2453
2454    fn or(self, rhs: Self) -> bool {
2455        self != 0 || rhs != 0
2456    }
2457
2458    fn bitxor(self, rhs: Self) -> Self {
2459        self ^ rhs
2460    }
2461
2462    fn bitor(self, rhs: Self) -> Self {
2463        self | rhs
2464    }
2465
2466    fn bitand(self, rhs: Self) -> Self {
2467        self & rhs
2468    }
2469
2470    fn bitshiftleft(self, rhs: Self) -> Self {
2471        self << rhs
2472    }
2473
2474    fn bitshiftright(self, rhs: Self) -> Self {
2475        self >> rhs
2476    }
2477
2478    fn and(self, rhs: Self) -> bool {
2479        self != 0 && rhs != 0
2480    }
2481
2482    fn max(self, rhs: Self) -> Self {
2483        Ord::max(self, rhs)
2484    }
2485
2486    fn max_value() -> Self {
2487        Self::MAX
2488    }
2489
2490    fn min_value() -> Self {
2491        Self::MIN
2492    }
2493
2494    fn is_equal(self, rhs: Self) -> bool {
2495        self == rhs
2496    }
2497
2498    fn epsilon() -> Self {
2499        0
2500    }
2501
2502    fn cast<T: Scalar>(self) -> T {
2503        use core::mem::transmute_copy as t;
2504        unsafe {
2505            match Self::dtype() {
2506                DType::BF16 => T::from_bf16(t(&self)),
2507                DType::F16 => T::from_f16(t(&self)),
2508                DType::F32 => T::from_f32(t(&self)),
2509                DType::F64 => T::from_f64(t(&self)),
2510                DType::U8 => T::from_u8(t(&self)),
2511                DType::U16 => T::from_u16(t(&self)),
2512                DType::U32 => T::from_u32(t(&self)),
2513                DType::U64 => T::from_u64(t(&self)),
2514                DType::I8 => T::from_i8(t(&self)),
2515                DType::I16 => T::from_i16(t(&self)),
2516                DType::I32 => T::from_i32(t(&self)),
2517                DType::I64 => T::from_i64(t(&self)),
2518                DType::Bool => T::from_bool(t(&self)),
2519            }
2520        }
2521    }
2522}
2523
2524impl Scalar for u64 {
2525    fn from_bf16(t: bf16) -> Self {
2526        t.to_f32() as u64
2527    }
2528
2529    fn from_f16(t: f16) -> Self {
2530        t.to_f32() as Self
2531    }
2532
2533    fn from_f32(t: f32) -> Self {
2534        t as u64
2535    }
2536
2537    fn from_f64(t: f64) -> Self {
2538        t as Self
2539    }
2540
2541    fn from_u8(t: u8) -> Self {
2542        t.into()
2543    }
2544
2545    fn from_u16(t: u16) -> Self {
2546        t.into()
2547    }
2548
2549    fn from_u32(t: u32) -> Self {
2550        t.into()
2551    }
2552
2553    fn from_u64(t: u64) -> Self {
2554        t
2555    }
2556
2557    fn from_i8(t: i8) -> Self {
2558        t.try_into().unwrap()
2559    }
2560
2561    fn from_i16(t: i16) -> Self {
2562        t.try_into().unwrap()
2563    }
2564
2565    fn from_i32(t: i32) -> Self {
2566        t.try_into().unwrap()
2567    }
2568
2569    fn from_i64(t: i64) -> Self {
2570        t.try_into().unwrap()
2571    }
2572
2573    fn from_bool(t: bool) -> Self {
2574        t.into()
2575    }
2576
2577    fn from_le_bytes(bytes: &[u8]) -> Self {
2578        Self::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7]])
2579    }
2580
2581    fn to_ne_bytes(&self) -> &[u8] {
2582        let i: *const Self = self;
2583        unsafe { std::slice::from_raw_parts(i.cast::<u8>(), std::mem::size_of::<Self>()) }
2584    }
2585
2586    fn dtype() -> DType {
2587        DType::U64
2588    }
2589
2590    fn zero() -> Self {
2591        0
2592    }
2593
2594    fn one() -> Self {
2595        1
2596    }
2597
2598    fn abs(self) -> Self {
2599        self
2600    }
2601
2602    fn neg(self) -> Self {
2603        self.wrapping_neg()
2604    }
2605
2606    fn exp2(self) -> Self {
2607        if self <= 63 { 2u64.pow(self as u32) } else { u64::MAX }
2608    }
2609
2610    fn log2(self) -> Self {
2611        self.ilog2() as u64
2612    }
2613
2614    fn relu(self) -> Self {
2615        self
2616    }
2617
2618    fn not(self) -> Self {
2619        u64::from(self == 0)
2620    }
2621
2622    fn nonzero(self) -> Self {
2623        u64::from(self != 0)
2624    }
2625
2626    fn add(self, rhs: Self) -> Self {
2627        self + rhs
2628    }
2629
2630    fn sub(self, rhs: Self) -> Self {
2631        self.wrapping_sub(rhs)
2632    }
2633
2634    fn mul(self, rhs: Self) -> Self {
2635        self * rhs
2636    }
2637
2638    fn div(self, rhs: Self) -> Self {
2639        self / rhs
2640    }
2641
2642    fn pow(self, rhs: Self) -> Self {
2643        i64::pow(self as i64, u32::try_from(rhs).unwrap()) as u64
2644    }
2645
2646    fn mod_(self, rhs: Self) -> Self {
2647        self % rhs
2648    }
2649
2650    fn cmplt(self, rhs: Self) -> bool {
2651        self < rhs
2652    }
2653
2654    fn cmpgt(self, rhs: Self) -> bool {
2655        self > rhs
2656    }
2657
2658    fn noteq(self, rhs: Self) -> bool {
2659        self != rhs
2660    }
2661
2662    fn or(self, rhs: Self) -> bool {
2663        self != 0 || rhs != 0
2664    }
2665
2666    fn bitxor(self, rhs: Self) -> Self {
2667        self ^ rhs
2668    }
2669
2670    fn bitor(self, rhs: Self) -> Self {
2671        self | rhs
2672    }
2673
2674    fn bitand(self, rhs: Self) -> Self {
2675        self & rhs
2676    }
2677
2678    fn bitshiftleft(self, rhs: Self) -> Self {
2679        self << rhs
2680    }
2681
2682    fn bitshiftright(self, rhs: Self) -> Self {
2683        self >> rhs
2684    }
2685
2686    fn and(self, rhs: Self) -> bool {
2687        self != 0 && rhs != 0
2688    }
2689
2690    fn max(self, rhs: Self) -> Self {
2691        Ord::max(self, rhs)
2692    }
2693
2694    fn max_value() -> Self {
2695        Self::MAX
2696    }
2697
2698    fn min_value() -> Self {
2699        Self::MIN
2700    }
2701
2702    fn is_equal(self, rhs: Self) -> bool {
2703        self == rhs
2704    }
2705
2706    fn epsilon() -> Self {
2707        0
2708    }
2709}
2710
2711impl Scalar for bool {
2712    fn from_bf16(t: bf16) -> Self {
2713        t != bf16::ZERO
2714    }
2715
2716    fn from_f16(t: f16) -> Self {
2717        t != f16::ZERO
2718    }
2719
2720    fn from_f32(t: f32) -> Self {
2721        t != 0.
2722    }
2723
2724    fn from_f64(t: f64) -> Self {
2725        t != 0.
2726    }
2727
2728    fn from_u8(t: u8) -> Self {
2729        t != 0
2730    }
2731
2732    fn from_u16(t: u16) -> Self {
2733        t != 0
2734    }
2735
2736    fn from_u32(t: u32) -> Self {
2737        t != 0
2738    }
2739
2740    fn from_u64(t: u64) -> Self {
2741        t != 0
2742    }
2743
2744    fn from_i8(t: i8) -> Self {
2745        t != 0
2746    }
2747
2748    fn from_i16(t: i16) -> Self {
2749        t != 0
2750    }
2751
2752    fn from_i32(t: i32) -> Self {
2753        t != 0
2754    }
2755
2756    fn from_i64(t: i64) -> Self {
2757        t != 0
2758    }
2759
2760    fn from_bool(t: bool) -> Self {
2761        t
2762    }
2763
2764    fn from_le_bytes(bytes: &[u8]) -> Self {
2765        bytes[0] != 0
2766    }
2767
2768    fn to_ne_bytes(&self) -> &[u8] {
2769        let i: *const Self = self;
2770        unsafe { std::slice::from_raw_parts(i.cast::<u8>(), std::mem::size_of::<Self>()) }
2771    }
2772
2773    fn dtype() -> DType {
2774        DType::Bool
2775    }
2776
2777    fn zero() -> Self {
2778        false
2779    }
2780
2781    fn one() -> Self {
2782        true
2783    }
2784
2785    fn abs(self) -> Self {
2786        self
2787    }
2788
2789    fn neg(self) -> Self {
2790        panic!()
2791    }
2792
2793    fn exp2(self) -> Self {
2794        panic!()
2795    }
2796
2797    fn log2(self) -> Self {
2798        panic!()
2799    }
2800
2801    fn relu(self) -> Self {
2802        panic!()
2803    }
2804
2805    fn not(self) -> Self {
2806        !self
2807    }
2808
2809    fn nonzero(self) -> Self {
2810        self
2811    }
2812
2813    fn add(self, rhs: Self) -> Self {
2814        self | rhs
2815    }
2816
2817    fn sub(self, rhs: Self) -> Self {
2818        let _ = rhs;
2819        panic!()
2820    }
2821
2822    fn mul(self, rhs: Self) -> Self {
2823        self & rhs
2824    }
2825
2826    fn div(self, rhs: Self) -> Self {
2827        let _ = rhs;
2828        panic!()
2829    }
2830
2831    fn pow(self, rhs: Self) -> Self {
2832        let _ = rhs;
2833        panic!()
2834    }
2835
2836    fn mod_(self, rhs: Self) -> Self {
2837        let _ = rhs;
2838        panic!()
2839    }
2840
2841    fn cmplt(self, rhs: Self) -> Self {
2842        !self & rhs
2843    }
2844
2845    fn cmpgt(self, rhs: Self) -> Self {
2846        self && !rhs
2847    }
2848
2849    fn noteq(self, rhs: Self) -> bool {
2850        self != rhs
2851    }
2852
2853    fn or(self, rhs: Self) -> Self {
2854        self || rhs
2855    }
2856
2857    fn bitxor(self, rhs: Self) -> Self {
2858        self ^ rhs
2859    }
2860
2861    fn bitor(self, rhs: Self) -> Self {
2862        self | rhs
2863    }
2864
2865    fn bitand(self, rhs: Self) -> Self {
2866        self & rhs
2867    }
2868
2869    fn bitshiftleft(self, _rhs: Self) -> Self {
2870        self
2871    }
2872
2873    fn bitshiftright(self, _rhs: Self) -> Self {
2874        false
2875    }
2876
2877    fn and(self, rhs: Self) -> bool {
2878        self && rhs
2879    }
2880
2881    fn max(self, rhs: Self) -> Self {
2882        <bool as Ord>::max(self, rhs)
2883    }
2884
2885    fn max_value() -> Self {
2886        true
2887    }
2888
2889    fn min_value() -> Self {
2890        false
2891    }
2892
2893    fn is_equal(self, rhs: Self) -> bool {
2894        self == rhs
2895    }
2896
2897    fn epsilon() -> Self {
2898        false
2899    }
2900
2901    fn cast<T: Scalar>(self) -> T {
2902        use core::mem::transmute_copy as t;
2903        unsafe {
2904            match Self::dtype() {
2905                DType::BF16 => T::from_bf16(t(&self)),
2906                DType::F16 => T::from_f16(t(&self)),
2907                DType::F32 => T::from_f32(t(&self)),
2908                DType::F64 => T::from_f64(t(&self)),
2909                DType::U8 => T::from_u8(t(&self)),
2910                DType::U16 => T::from_u16(t(&self)),
2911                DType::U32 => T::from_u32(t(&self)),
2912                DType::U64 => T::from_u64(t(&self)),
2913                DType::I8 => T::from_i8(t(&self)),
2914                DType::I16 => T::from_i16(t(&self)),
2915                DType::I32 => T::from_i32(t(&self)),
2916                DType::I64 => T::from_i64(t(&self)),
2917                DType::Bool => T::from_bool(t(&self)),
2918            }
2919        }
2920    }
2921}