1use crate::dtype::DType;
7use half::{bf16, f16};
8
9pub trait Scalar: Copy + Clone + Sized + core::fmt::Debug + 'static + PartialEq + Send + Sync + PartialOrd {
11 #[must_use]
13 fn from_bf16(t: bf16) -> Self;
14 #[must_use]
16 fn from_f16(t: f16) -> Self;
17 #[must_use]
19 fn from_f32(t: f32) -> Self;
20 #[must_use]
22 fn from_f64(t: f64) -> Self;
23 #[must_use]
25 fn from_u8(t: u8) -> Self;
26 #[must_use]
28 fn from_u16(t: u16) -> Self;
29 #[must_use]
31 fn from_u32(t: u32) -> Self;
32 #[must_use]
34 fn from_u64(t: u64) -> Self;
35 #[must_use]
37 fn from_i8(t: i8) -> Self;
38 fn from_i16(t: i16) -> Self;
40 #[must_use]
41 fn from_i32(t: i32) -> Self;
43 #[must_use]
45 fn from_i64(t: i64) -> Self;
46 #[must_use]
48 fn from_bool(t: bool) -> Self;
49 #[must_use]
51 fn from_le_bytes(bytes: &[u8]) -> Self;
52 #[must_use]
54 #[allow(clippy::ptr_as_ptr)]
55 fn to_ne_bytes(&self) -> &[u8];
56 #[must_use]
58 fn bit_size() -> u8 {
59 Self::dtype().bit_size()
60 }
61 #[must_use]
63 fn dtype() -> DType;
64 #[must_use]
66 fn zero() -> Self;
67 #[must_use]
69 fn one() -> Self;
70 #[must_use]
72 fn abs(self) -> Self;
73 #[must_use]
75 fn neg(self) -> Self;
76 #[must_use]
78 fn exp2(self) -> Self;
79 #[must_use]
81 fn log2(self) -> Self;
82 #[must_use]
84 fn relu(self) -> Self;
85 #[must_use]
87 fn not(self) -> Self;
88 #[must_use]
90 fn nonzero(self) -> Self;
91 #[must_use]
93 fn add(self, rhs: Self) -> Self;
94 #[must_use]
96 fn sub(self, rhs: Self) -> Self;
97 #[must_use]
99 fn mul(self, rhs: Self) -> Self;
100 #[must_use]
102 fn div(self, rhs: Self) -> Self;
103 #[must_use]
105 fn pow(self, rhs: Self) -> Self;
106 #[must_use]
108 fn mod_(self, rhs: Self) -> Self;
109 #[must_use]
111 fn cmplt(self, rhs: Self) -> bool;
112 #[must_use]
114 fn cmpgt(self, rhs: Self) -> bool;
115 #[must_use]
117 fn noteq(self, rhs: Self) -> bool;
118 #[must_use]
120 fn or(self, rhs: Self) -> bool;
121 #[must_use]
123 fn bitxor(self, rhs: Self) -> Self;
124 #[must_use]
126 fn bitor(self, rhs: Self) -> Self;
127 #[must_use]
129 fn bitand(self, rhs: Self) -> Self;
130 #[must_use]
132 fn bitshiftleft(self, rhs: Self) -> Self;
133 #[must_use]
135 fn bitshiftright(self, rhs: Self) -> Self;
136 #[must_use]
138 fn and(self, rhs: Self) -> bool;
139 #[must_use]
141 fn max(self, rhs: Self) -> Self;
142 #[must_use]
144 fn max_value() -> Self;
145 #[must_use]
147 fn min_value() -> Self;
148 #[must_use]
151 fn is_equal(self, rhs: Self) -> bool;
152 #[must_use]
154 fn cast<T: Scalar>(self) -> T {
155 use core::mem::transmute_copy as t;
156 unsafe {
157 match Self::dtype() {
158 DType::BF16 => T::from_bf16(t(&self)),
159 DType::F16 => T::from_f16(t(&self)),
160 DType::F32 => T::from_f32(t(&self)),
161 DType::F64 => T::from_f64(t(&self)),
162 DType::U8 => T::from_u8(t(&self)),
163 DType::U16 => T::from_u16(t(&self)),
164 DType::U32 => T::from_u32(t(&self)),
165 DType::U64 => T::from_u64(t(&self)),
166 DType::I8 => T::from_i8(t(&self)),
167 DType::I16 => T::from_i16(t(&self)),
168 DType::I32 => T::from_i32(t(&self)),
169 DType::I64 => T::from_i64(t(&self)),
170 DType::Bool => T::from_bool(t(&self)),
171 }
172 }
173 }
174 #[must_use]
176 fn epsilon() -> Self {
177 Self::zero()
178 }
179}
180
181pub trait Float: Scalar {
183 #[must_use]
185 fn floor(self) -> Self;
186 #[must_use]
188 fn reciprocal(self) -> Self;
189 #[must_use]
191 fn sin(self) -> Self;
192 #[must_use]
194 fn cos(self) -> Self;
195 #[must_use]
197 fn sqrt(self) -> Self;
198 #[must_use]
200 fn trunc(self) -> Self;
201}
202
203impl Scalar for bf16 {
204 fn from_bf16(t: bf16) -> Self {
205 t
206 }
207
208 fn from_f16(t: f16) -> Self {
209 bf16::from_f32(t.into())
210 }
211
212 fn from_f32(t: f32) -> Self {
213 bf16::from_f32(t)
214 }
215
216 fn from_f64(t: f64) -> Self {
217 bf16::from_f64(t)
218 }
219
220 fn from_u8(t: u8) -> Self {
221 bf16::from_f32(f32::from(t))
222 }
223
224 fn from_u16(t: u16) -> Self {
225 bf16::from_f32(f32::from(t))
226 }
227
228 fn from_u32(t: u32) -> Self {
229 bf16::from_f64(f64::from(t))
230 }
231
232 #[allow(clippy::cast_precision_loss)]
233 fn from_u64(t: u64) -> Self {
234 bf16::from_f64(t as f64)
235 }
236
237 fn from_i8(t: i8) -> Self {
238 bf16::from_f32(f32::from(t))
239 }
240
241 fn from_i16(t: i16) -> Self {
242 bf16::from_f32(f32::from(t))
243 }
244
245 #[allow(clippy::cast_possible_truncation)]
246 fn from_i32(t: i32) -> Self {
247 bf16::from_f32(t as f32)
248 }
249
250 #[allow(clippy::cast_possible_truncation)]
251 fn from_i64(t: i64) -> Self {
252 bf16::from_f32(t as f32)
253 }
254
255 fn from_bool(t: bool) -> Self {
256 bf16::from_f32(f32::from(t))
257 }
258
259 fn from_le_bytes(bytes: &[u8]) -> Self {
260 bf16::from_le_bytes([bytes[0], bytes[1]])
261 }
262
263 fn to_ne_bytes(&self) -> &[u8] {
264 let x: *const Self = self;
265 unsafe { std::slice::from_raw_parts(x.cast(), 2) }
266 }
267
268 fn dtype() -> DType {
269 DType::BF16
270 }
271
272 fn zero() -> Self {
273 bf16::ZERO
274 }
275
276 fn one() -> Self {
277 bf16::ONE
278 }
279
280 fn abs(self) -> Self {
281 self.max(-self)
282 }
283
284 fn neg(self) -> Self {
285 -self
286 }
287
288 fn exp2(self) -> Self {
289 bf16::from_f64(f64::from(self).exp2())
290 }
291
292 fn log2(self) -> Self {
293 bf16::from_f64(f64::from(self).log2())
294 }
295
296 fn relu(self) -> Self {
297 Scalar::max(self, Self::ZERO)
298 }
299
300 fn not(self) -> Self {
301 bf16::from_f32(if f64::from(self) == 0.0 { 0.0 } else { 1.0 })
302 }
303
304 fn nonzero(self) -> Self {
305 bf16::from_f32(if f64::from(self) == 0.0 { 0.0 } else { 1.0 })
306 }
307
308 fn add(self, rhs: Self) -> Self {
309 self + rhs
310 }
311
312 fn sub(self, rhs: Self) -> Self {
313 self - rhs
314 }
315
316 fn mul(self, rhs: Self) -> Self {
317 self * rhs
318 }
319
320 fn div(self, rhs: Self) -> Self {
321 self / rhs
322 }
323
324 fn pow(self, rhs: Self) -> Self {
325 bf16::from_f64(f64::from(self).powf(f64::from(rhs)))
326 }
327
328 fn mod_(self, rhs: Self) -> Self {
329 self % rhs
330 }
331
332 fn cmplt(self, rhs: Self) -> bool {
333 self < rhs
334 }
335
336 fn cmpgt(self, rhs: Self) -> bool {
337 self > rhs
338 }
339
340 fn noteq(self, rhs: Self) -> bool {
341 self != rhs
342 }
343
344 fn or(self, rhs: Self) -> bool {
345 self != Self::ZERO || rhs != Self::ZERO
346 }
347
348 fn bitxor(self, rhs: Self) -> Self {
349 let a = f64::from(self) as i64;
350 let b = f64::from(rhs) as i64;
351 bf16::from_f32((a ^ b) as f32)
352 }
353
354 fn bitor(self, rhs: Self) -> Self {
355 let a = f64::from(self) as i64;
356 let b = f64::from(rhs) as i64;
357 bf16::from_f32((a | b) as f32)
358 }
359
360 fn bitand(self, rhs: Self) -> Self {
361 let a = f64::from(self) as i64;
362 let b = f64::from(rhs) as i64;
363 bf16::from_f32((a & b) as f32)
364 }
365
366 fn bitshiftleft(self, rhs: Self) -> Self {
367 let a = f64::from(self) as i64;
368 let b = f64::from(rhs) as i64;
369 bf16::from_f32((a << b) as f32)
370 }
371
372 fn bitshiftright(self, rhs: Self) -> Self {
373 let a = f64::from(self) as i64;
374 let b = f64::from(rhs) as i64;
375 bf16::from_f32((a >> b) as f32)
376 }
377
378 fn and(self, rhs: Self) -> bool {
379 self != Self::ZERO && rhs != Self::ZERO
380 }
381
382 fn max(self, rhs: Self) -> Self {
383 self.max(rhs)
384 }
385
386 fn max_value() -> Self {
387 bf16::MAX
388 }
389
390 fn min_value() -> Self {
391 bf16::MIN
392 }
393
394 fn is_equal(self, rhs: Self) -> bool {
395 self == rhs
396 }
397
398 fn epsilon() -> Self {
399 bf16::MIN_POSITIVE
400 }
401}
402
403impl Float for bf16 {
404 fn reciprocal(self) -> Self {
405 bf16::ONE / self
406 }
407
408 fn floor(self) -> Self {
409 bf16::from_f32(self.to_f32().floor())
410 }
411
412 fn sin(self) -> Self {
413 bf16::from_f32(self.to_f32().sin())
414 }
415
416 fn cos(self) -> Self {
417 bf16::from_f32(self.to_f32().cos())
418 }
419
420 fn sqrt(self) -> Self {
421 bf16::from_f32(self.to_f32().sqrt())
422 }
423
424 fn trunc(self) -> Self {
425 bf16::from_f32(self.to_f32().trunc())
426 }
427}
428
429impl Scalar for f16 {
430 fn from_bf16(t: bf16) -> Self {
431 f16::from_f32(t.to_f32())
432 }
433
434 fn from_f16(t: f16) -> Self {
435 f16::from_f32(t.to_f32())
436 }
437
438 fn from_f32(t: f32) -> Self {
439 f16::from_f32(t)
440 }
441
442 fn from_f64(t: f64) -> Self {
443 f16::from_f64(t)
444 }
445
446 fn from_u8(t: u8) -> Self {
447 f16::from_f32(t as f32)
448 }
449
450 fn from_u16(t: u16) -> Self {
451 f16::from_f32(t as f32)
452 }
453
454 fn from_u32(t: u32) -> Self {
455 f16::from_f64(t.into())
456 }
457
458 fn from_u64(t: u64) -> Self {
459 f16::from_f64(t as f64)
460 }
461
462 fn from_i8(t: i8) -> Self {
463 f16::from_f32(t as f32)
464 }
465
466 fn from_i16(t: i16) -> Self {
467 f16::from_f32(t as f32)
468 }
469
470 #[allow(clippy::cast_lossless)]
471 fn from_i32(t: i32) -> Self {
472 f16::from_f64(t as f64)
473 }
474
475 #[allow(clippy::cast_precision_loss)]
476 fn from_i64(t: i64) -> Self {
477 f16::from_f64(t as f64)
478 }
479
480 #[allow(clippy::cast_lossless)]
481 fn from_bool(t: bool) -> Self {
482 f16::from_f64(t as i8 as f64)
483 }
484
485 fn from_le_bytes(bytes: &[u8]) -> Self {
486 f16::from_le_bytes([bytes[0], bytes[1]])
487 }
488
489 fn to_ne_bytes(&self) -> &[u8] {
490 let i: *const Self = self;
491 unsafe { std::slice::from_raw_parts(i.cast::<u8>(), std::mem::size_of::<Self>()) }
492 }
493
494 fn dtype() -> DType {
495 DType::F16
496 }
497
498 fn zero() -> Self {
499 f16::ZERO
500 }
501
502 fn one() -> Self {
503 f16::ONE
504 }
505
506 fn abs(self) -> Self {
507 self.max(-self)
508 }
509
510 fn neg(self) -> Self {
511 -self
512 }
513
514 fn exp2(self) -> Self {
515 f16::from_f32(self.to_f32().exp2())
516 }
517
518 fn log2(self) -> Self {
519 f16::from_f32(self.to_f32().log2())
520 }
521
522 fn relu(self) -> Self {
523 self.max(f16::ZERO)
524 }
525
526 fn not(self) -> Self {
527 f16::from_f32(if f32::from(self) == 0.0 { 0.0 } else { 1.0 })
528 }
529
530 fn nonzero(self) -> Self {
531 f16::from_f32(if f32::from(self) == 0.0 { 0.0 } else { 1.0 })
532 }
533
534 fn add(self, rhs: Self) -> Self {
535 self + rhs
536 }
537
538 fn sub(self, rhs: Self) -> Self {
539 self - rhs
540 }
541
542 fn mul(self, rhs: Self) -> Self {
543 self * rhs
544 }
545
546 fn div(self, rhs: Self) -> Self {
547 self / rhs
548 }
549
550 fn pow(self, rhs: Self) -> Self {
551 f16::from_f32(self.to_f32().pow(rhs.to_f32()))
552 }
553
554 fn mod_(self, rhs: Self) -> Self {
555 self % rhs
556 }
557
558 fn cmplt(self, rhs: Self) -> bool {
559 self < rhs
560 }
561
562 fn cmpgt(self, rhs: Self) -> bool {
563 self > rhs
564 }
565
566 fn noteq(self, rhs: Self) -> bool {
567 self != rhs
568 }
569
570 fn or(self, rhs: Self) -> bool {
571 self != f16::ZERO || rhs != f16::ZERO
572 }
573
574 fn bitxor(self, rhs: Self) -> Self {
575 let ix = self.to_bits() ^ rhs.to_bits();
576 f16::from_le_bytes([ix as u8, (ix >> 8) as u8])
577 }
578
579 fn bitor(self, rhs: Self) -> Self {
580 let ix = self.to_bits() | rhs.to_bits();
581 f16::from_le_bytes([ix as u8, (ix >> 8) as u8])
582 }
583
584 fn bitand(self, rhs: Self) -> Self {
585 let ix = self.to_bits() & rhs.to_bits();
586 f16::from_le_bytes([ix as u8, (ix >> 8) as u8])
587 }
588
589 fn bitshiftleft(self, rhs: Self) -> Self {
590 let lhs_f32 = self.to_f32();
591 let rhs_f32 = rhs.to_f32();
592 let lhs_bits = lhs_f32.to_bits() as i32;
593 let rhs_bits = rhs_f32.to_bits() as i32;
594 let result = f32::from_bits((lhs_bits << rhs_bits) as u32);
595 f16::from_f32(result)
596 }
597
598 fn bitshiftright(self, rhs: Self) -> Self {
599 let lhs_f32 = self.to_f32();
600 let rhs_f32 = rhs.to_f32();
601 let lhs_bits = lhs_f32.to_bits() as i32;
602 let rhs_bits = rhs_f32.to_bits() as i32;
603 let result = f32::from_bits((lhs_bits >> rhs_bits) as u32);
604 f16::from_f32(result)
605 }
606
607 fn and(self, rhs: Self) -> bool {
608 self != f16::ZERO && rhs != f16::ZERO
609 }
610
611 fn max(self, rhs: Self) -> Self {
612 f16::max(self, rhs)
613 }
614
615 fn max_value() -> Self {
616 f16::MAX
617 }
618
619 fn min_value() -> Self {
620 f16::MIN
621 }
622
623 fn is_equal(self, rhs: Self) -> bool {
624 (self == -Self::INFINITY && rhs == -Self::INFINITY)
625 || (self.is_nan() && rhs.is_nan())
626 || self.sub(rhs).abs() < self.abs() * f16::from_f32(0.0001)
627 }
628
629 fn epsilon() -> Self {
630 f16::from_f32(0.00001)
631 }
632}
633
634impl Float for f16 {
635 fn reciprocal(self) -> Self {
636 f16::ONE / self
637 }
638
639 fn sin(self) -> Self {
640 f16::from_f32(self.to_f32().sin())
641 }
642
643 fn cos(self) -> Self {
644 f16::from_f32(self.to_f32().cos())
645 }
646
647 fn sqrt(self) -> Self {
648 f16::from_f32(self.to_f32().sqrt())
649 }
650
651 fn floor(self) -> Self {
652 f16::from_f32(self.to_f32().floor())
653 }
654
655 fn trunc(self) -> Self {
656 f16::from_f32(self.to_f32().trunc())
657 }
658}
659
660impl Scalar for f32 {
661 fn from_bf16(t: bf16) -> Self {
662 t.into()
663 }
664
665 fn from_f16(t: f16) -> Self {
666 t.into()
667 }
668
669 fn from_f32(t: f32) -> Self {
670 t
671 }
672
673 #[allow(clippy::cast_possible_truncation)]
674 fn from_f64(t: f64) -> Self {
675 t as Self
676 }
677
678 fn from_u8(t: u8) -> Self {
679 f32::from(t)
680 }
681
682 fn from_u16(t: u16) -> Self {
683 t.into()
684 }
685
686 #[allow(clippy::cast_precision_loss)]
687 fn from_u32(t: u32) -> Self {
688 t as f32
689 }
690
691 #[allow(clippy::cast_precision_loss)]
692 fn from_u64(t: u64) -> Self {
693 t as f32
694 }
695
696 fn from_i8(t: i8) -> Self {
697 f32::from(t)
698 }
699
700 fn from_i16(t: i16) -> Self {
701 f32::from(t)
702 }
703
704 #[allow(clippy::cast_precision_loss)]
705 fn from_i32(t: i32) -> Self {
706 t as f32
707 }
708
709 #[allow(clippy::cast_precision_loss)]
710 fn from_i64(t: i64) -> Self {
711 t as f32
712 }
713
714 fn from_bool(t: bool) -> Self {
715 f32::from(i8::from(t))
716 }
717
718 fn from_le_bytes(bytes: &[u8]) -> Self {
719 f32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]])
720 }
721
722 fn to_ne_bytes(&self) -> &[u8] {
723 let i: *const Self = self;
724 unsafe { std::slice::from_raw_parts(i.cast::<u8>(), std::mem::size_of::<Self>()) }
725 }
726
727 fn dtype() -> DType {
728 DType::F32
729 }
730
731 fn zero() -> Self {
732 0.
733 }
734
735 fn one() -> Self {
736 1.
737 }
738
739 fn abs(self) -> Self {
740 self.abs()
741 }
742
743 fn neg(self) -> Self {
744 -self
745 }
746
747 fn exp2(self) -> Self {
748 f32::exp2(self)
749 }
750
751 fn log2(self) -> Self {
752 self.log2()
753 }
754
755 fn relu(self) -> Self {
756 self.max(0.)
757 }
758
759 fn not(self) -> Self {
760 if self == 0. { 1. } else { 0. }
761 }
762
763 fn nonzero(self) -> Self {
764 f32::from(i8::from(self != 0.))
765 }
766
767 fn add(self, rhs: Self) -> Self {
768 self + rhs
769 }
770
771 fn sub(self, rhs: Self) -> Self {
772 self - rhs
773 }
774
775 fn mul(self, rhs: Self) -> Self {
776 self * rhs
777 }
778
779 fn div(self, rhs: Self) -> Self {
780 self / rhs
781 }
782
783 fn pow(self, rhs: Self) -> Self {
784 self.powf(rhs)
785 }
786
787 fn mod_(self, rhs: Self) -> Self {
788 self % rhs
789 }
790
791 fn cmplt(self, rhs: Self) -> bool {
792 self < rhs
793 }
794
795 fn cmpgt(self, rhs: Self) -> bool {
796 self > rhs
797 }
798
799 fn noteq(self, rhs: Self) -> bool {
800 !self.is_equal(rhs)
801 }
802
803 fn or(self, rhs: Self) -> bool {
804 self != 0. || rhs != 0.
805 }
806
807 fn bitxor(self, rhs: Self) -> Self {
808 let rhs_bits = rhs.to_bits();
809 f32::from_bits(self.to_bits() ^ rhs_bits)
810 }
811
812 fn bitor(self, rhs: Self) -> Self {
813 let rhs_bits = rhs.to_bits();
814 f32::from_bits(self.to_bits() | rhs_bits)
815 }
816
817 fn bitand(self, rhs: Self) -> Self {
818 let rhs_bits = rhs.to_bits();
819 f32::from_bits(self.to_bits() & rhs_bits)
820 }
821
822 fn bitshiftleft(self, rhs: Self) -> Self {
823 let rhs_shift = (rhs.to_bits() & 0xFF) as u32;
824 let ix = (self.to_bits() as u64) << rhs_shift;
825 f32::from_bits(ix as u32)
826 }
827
828 fn bitshiftright(self, rhs: Self) -> Self {
829 let rhs_shift = (rhs.to_bits() & 0xFF) as u32;
830 let ix = (self.to_bits() as u32) >> rhs_shift;
831 f32::from_bits(ix)
832 }
833
834 fn and(self, rhs: Self) -> bool {
835 self != 0. && rhs != 0.
836 }
837
838 fn max(self, rhs: Self) -> Self {
839 f32::max(self, rhs)
840 }
841
842 fn max_value() -> Self {
843 f32::MAX
844 }
845
846 fn min_value() -> Self {
847 f32::MIN
848 }
849
850 fn is_equal(self, rhs: Self) -> bool {
851 let a = self;
852 let b = rhs;
853 if a.is_nan() && b.is_nan() {
854 return true;
855 }
856 #[allow(clippy::float_cmp)]
857 if a == b {
858 return true;
859 }
860 let diff = (a - b).abs();
861 let max_abs = a.abs().max(b.abs());
862 let rel_tol = 1e-3 * max_abs; let abs_tol = 2e-7; diff < rel_tol || diff < abs_tol
865 }
866
867 fn epsilon() -> Self {
868 0.0001
869 }
870}
871
872impl Float for f32 {
873 fn sin(self) -> Self {
874 f32::sin(self)
879 }
880
881 fn floor(self) -> Self {
882 f32::floor(self)
883 }
884
885 fn cos(self) -> Self {
886 f32::cos(self)
894 }
895
896 fn sqrt(self) -> Self {
897 f32::sqrt(self)
904 }
905
906 fn reciprocal(self) -> Self {
907 1.0 / self
908 }
909
910 fn trunc(self) -> Self {
911 f32::trunc(self)
912 }
913}
914
915impl Scalar for f64 {
916 fn from_bf16(t: bf16) -> Self {
917 t.into()
918 }
919
920 fn from_f16(t: f16) -> Self {
921 t.into()
922 }
923
924 fn from_f32(t: f32) -> Self {
925 f64::from(t)
926 }
927
928 fn from_f64(t: f64) -> Self {
929 t
930 }
931
932 fn from_u8(t: u8) -> Self {
933 f64::from(t)
934 }
935
936 fn from_u16(t: u16) -> Self {
937 t.into()
938 }
939
940 fn from_u32(t: u32) -> Self {
941 t.into()
942 }
943
944 #[allow(clippy::cast_precision_loss)]
945 fn from_u64(t: u64) -> Self {
946 t as f64
947 }
948
949 fn from_i8(t: i8) -> Self {
950 t.into()
951 }
952
953 fn from_i16(t: i16) -> Self {
954 t.into()
955 }
956
957 fn from_i32(t: i32) -> Self {
958 t.into()
959 }
960
961 #[allow(clippy::cast_precision_loss)]
962 fn from_i64(t: i64) -> Self {
963 t as f64
964 }
965
966 fn from_bool(t: bool) -> Self {
967 t.into()
968 }
969
970 fn from_le_bytes(bytes: &[u8]) -> Self {
971 f64::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7]])
972 }
973
974 fn to_ne_bytes(&self) -> &[u8] {
975 let i: *const Self = self;
976 unsafe { std::slice::from_raw_parts(i.cast::<u8>(), std::mem::size_of::<Self>()) }
977 }
978
979 fn dtype() -> DType {
980 DType::F64
981 }
982
983 fn zero() -> Self {
984 0.
985 }
986
987 fn one() -> Self {
988 1.
989 }
990
991 fn abs(self) -> Self {
992 self.abs()
993 }
994
995 fn neg(self) -> Self {
996 -self
997 }
998
999 fn exp2(self) -> Self {
1000 f64::exp2(self)
1001 }
1002
1003 fn log2(self) -> Self {
1004 self.log2()
1005 }
1006
1007 fn relu(self) -> Self {
1008 self.max(0.)
1009 }
1010
1011 fn not(self) -> Self {
1012 if self == 0. { 1. } else { 0. }
1013 }
1014
1015 fn nonzero(self) -> Self {
1016 u8::from(self != 0.).into()
1017 }
1018
1019 fn add(self, rhs: Self) -> Self {
1020 self + rhs
1021 }
1022
1023 fn sub(self, rhs: Self) -> Self {
1024 self - rhs
1025 }
1026
1027 fn mul(self, rhs: Self) -> Self {
1028 self * rhs
1029 }
1030
1031 fn div(self, rhs: Self) -> Self {
1032 self / rhs
1033 }
1034
1035 fn pow(self, rhs: Self) -> Self {
1036 self.powf(rhs)
1037 }
1038
1039 fn mod_(self, rhs: Self) -> Self {
1040 self % rhs
1041 }
1042
1043 fn cmplt(self, rhs: Self) -> bool {
1044 self < rhs
1045 }
1046
1047 fn cmpgt(self, rhs: Self) -> bool {
1048 self > rhs
1049 }
1050
1051 fn noteq(self, rhs: Self) -> bool {
1052 !self.is_equal(rhs)
1053 }
1054
1055 fn or(self, rhs: Self) -> bool {
1056 self != 0. || rhs != 0.
1057 }
1058
1059 fn bitxor(self, rhs: Self) -> Self {
1060 f64::from_bits(self.to_bits() ^ rhs.to_bits())
1061 }
1062
1063 fn bitor(self, rhs: Self) -> Self {
1064 f64::from_bits(self.to_bits() | rhs.to_bits())
1065 }
1066
1067 fn bitand(self, rhs: Self) -> Self {
1068 f64::from_bits(self.to_bits() & rhs.to_bits())
1069 }
1070
1071 fn bitshiftleft(self, rhs: Self) -> Self {
1072 let rhs_shift = (rhs.to_bits() & 0xFF) as u32;
1073 let ix = (self.to_bits() as u64) << rhs_shift;
1074 f64::from_bits(ix)
1075 }
1076
1077 fn bitshiftright(self, rhs: Self) -> Self {
1078 let rhs_shift = (rhs.to_bits() & 0xFF) as i32;
1079 let ix = ((self.to_bits() as u64) >> rhs_shift) as u32;
1080 f64::from_bits(ix as u64)
1081 }
1082
1083 fn and(self, rhs: Self) -> bool {
1084 self != 0. && rhs != 0.
1085 }
1086
1087 fn max(self, rhs: Self) -> Self {
1088 f64::max(self, rhs)
1089 }
1090
1091 fn max_value() -> Self {
1092 f64::MAX
1093 }
1094
1095 fn min_value() -> Self {
1096 f64::MIN
1097 }
1098
1099 fn is_equal(self, rhs: Self) -> bool {
1100 (self == -f64::INFINITY && rhs == -f64::INFINITY) || (self - rhs).abs() <= self.abs() * 0.001
1102 }
1103
1104 fn epsilon() -> Self {
1105 0.00001
1106 }
1107}
1108
1109impl Float for f64 {
1110 fn reciprocal(self) -> Self {
1111 1.0 / self
1112 }
1113
1114 fn floor(self) -> Self {
1115 self.floor()
1116 }
1117
1118 fn sin(self) -> Self {
1119 f64::sin(self)
1120 }
1121
1122 fn cos(self) -> Self {
1123 f64::cos(self)
1124 }
1125
1126 fn sqrt(self) -> Self {
1127 f64::sqrt(self)
1128 }
1129
1130 fn trunc(self) -> Self {
1131 self.trunc()
1132 }
1133}
1134
1135impl Scalar for i8 {
1136 #[allow(clippy::cast_possible_truncation)]
1137 fn from_bf16(t: bf16) -> Self {
1138 let t: f32 = t.into();
1139 t as Self
1140 }
1141
1142 #[allow(clippy::cast_possible_truncation)]
1143 fn from_f16(t: f16) -> Self {
1144 let t: f32 = t.into();
1145 t as Self
1146 }
1147
1148 #[allow(clippy::cast_possible_truncation)]
1149 fn from_f32(t: f32) -> Self {
1150 t as Self
1151 }
1152
1153 #[allow(clippy::cast_possible_truncation)]
1154 fn from_f64(t: f64) -> Self {
1155 t as Self
1156 }
1157
1158 fn from_u8(t: u8) -> Self {
1159 t.try_into().unwrap()
1160 }
1161
1162 fn from_u16(t: u16) -> Self {
1163 t.try_into().unwrap()
1164 }
1165
1166 fn from_u32(t: u32) -> Self {
1167 t.try_into().unwrap()
1168 }
1169
1170 fn from_u64(t: u64) -> Self {
1171 t.try_into().unwrap()
1172 }
1173
1174 fn from_i8(t: i8) -> Self {
1175 t
1176 }
1177
1178 fn from_i16(t: i16) -> Self {
1179 Self::try_from(t).unwrap()
1180 }
1181
1182 fn from_i32(t: i32) -> Self {
1183 Self::try_from(t).unwrap()
1184 }
1185
1186 fn from_i64(t: i64) -> Self {
1187 Self::try_from(t).unwrap()
1188 }
1189
1190 fn from_bool(t: bool) -> Self {
1191 Self::from(t)
1192 }
1193
1194 fn from_le_bytes(bytes: &[u8]) -> Self {
1195 i8::from_le_bytes([bytes[0]])
1196 }
1197
1198 fn to_ne_bytes(&self) -> &[u8] {
1199 let i: *const Self = self;
1200 unsafe { std::slice::from_raw_parts(i.cast::<u8>(), std::mem::size_of::<Self>()) }
1201 }
1202
1203 fn dtype() -> DType {
1204 DType::I8
1205 }
1206
1207 fn zero() -> Self {
1208 0
1209 }
1210
1211 fn one() -> Self {
1212 1
1213 }
1214
1215 fn abs(self) -> Self {
1216 self.abs()
1217 }
1218
1219 fn neg(self) -> Self {
1220 -self
1221 }
1222
1223 fn exp2(self) -> Self {
1224 2i32.pow(self as u32) as i8
1225 }
1226
1227 fn log2(self) -> Self {
1228 f64::from(self).log2() as i8
1229 }
1230
1231 fn relu(self) -> Self {
1232 Scalar::max(self, 0)
1233 }
1234
1235 fn not(self) -> Self {
1236 i8::from(self == 0)
1237 }
1238
1239 fn nonzero(self) -> Self {
1240 i8::from(self != 0)
1241 }
1242
1243 fn add(self, rhs: Self) -> Self {
1244 self + rhs
1245 }
1246
1247 fn sub(self, rhs: Self) -> Self {
1248 self - rhs
1249 }
1250
1251 fn mul(self, rhs: Self) -> Self {
1252 self * rhs
1253 }
1254
1255 fn div(self, rhs: Self) -> Self {
1256 self / rhs
1257 }
1258
1259 fn pow(self, rhs: Self) -> Self {
1260 if rhs >= 0 {
1261 return self.pow(rhs as u32);
1262 }
1263 if self == 1 {
1264 return 1;
1265 }
1266 if self == -1 {
1267 return if rhs % 2 == 0 { 1 } else { -1 };
1268 }
1269 0
1270 }
1271
1272 fn mod_(self, rhs: Self) -> Self {
1273 self % rhs
1274 }
1275
1276 fn cmplt(self, rhs: Self) -> bool {
1277 self < rhs
1278 }
1279
1280 fn cmpgt(self, rhs: Self) -> bool {
1281 self > rhs
1282 }
1283
1284 fn noteq(self, rhs: Self) -> bool {
1285 self != rhs
1286 }
1287
1288 fn or(self, rhs: Self) -> bool {
1289 self != 0 || rhs != 0
1290 }
1291
1292 fn bitxor(self, rhs: Self) -> Self {
1293 self ^ rhs
1294 }
1295
1296 fn bitor(self, rhs: Self) -> Self {
1297 self | rhs
1298 }
1299
1300 fn bitand(self, rhs: Self) -> Self {
1301 self & rhs
1302 }
1303
1304 fn bitshiftleft(self, rhs: Self) -> Self {
1305 self.wrapping_shl(rhs as u32)
1306 }
1307
1308 fn bitshiftright(self, rhs: Self) -> Self {
1309 self.wrapping_shr(rhs as u32)
1310 }
1311
1312 fn and(self, rhs: Self) -> bool {
1313 self != 0 && rhs != 0
1314 }
1315
1316 fn max(self, rhs: Self) -> Self {
1317 <i8 as Ord>::max(self, rhs)
1318 }
1319
1320 fn max_value() -> Self {
1321 i8::MAX
1322 }
1323
1324 fn min_value() -> Self {
1325 i8::MIN
1326 }
1327
1328 fn is_equal(self, rhs: Self) -> bool {
1329 self == rhs
1330 }
1331
1332 fn epsilon() -> Self {
1333 0
1334 }
1335}
1336
1337impl Scalar for i16 {
1338 #[allow(clippy::cast_possible_truncation)]
1339 fn from_bf16(t: bf16) -> Self {
1340 t.to_f32() as i16
1341 }
1342
1343 #[allow(clippy::cast_possible_truncation)]
1344 fn from_f16(t: f16) -> Self {
1345 t.to_f32() as i16
1346 }
1347
1348 #[allow(clippy::cast_possible_truncation)]
1349 fn from_f32(t: f32) -> Self {
1350 t as i16
1351 }
1352
1353 #[allow(clippy::cast_possible_truncation)]
1354 fn from_f64(t: f64) -> Self {
1355 t as i16
1356 }
1357
1358 fn from_u8(t: u8) -> Self {
1359 t.into()
1360 }
1361
1362 #[allow(clippy::cast_possible_truncation)]
1363 #[allow(clippy::cast_possible_wrap)]
1364 fn from_u16(t: u16) -> Self {
1365 t as i16
1366 }
1367
1368 #[allow(clippy::cast_possible_truncation)]
1369 fn from_u32(t: u32) -> Self {
1370 t as i16
1371 }
1372
1373 fn from_u64(t: u64) -> Self {
1374 t.try_into().unwrap()
1375 }
1376
1377 fn from_i8(t: i8) -> Self {
1378 t.into()
1379 }
1380
1381 fn from_i16(t: i16) -> Self {
1382 t
1383 }
1384
1385 #[allow(clippy::cast_possible_truncation)]
1386 fn from_i32(t: i32) -> Self {
1387 t as i16
1388 }
1389
1390 #[allow(clippy::cast_possible_truncation)]
1391 fn from_i64(t: i64) -> Self {
1392 t as i16
1393 }
1394
1395 fn from_bool(t: bool) -> Self {
1396 t.into()
1397 }
1398
1399 fn from_le_bytes(bytes: &[u8]) -> Self {
1400 i16::from_le_bytes([bytes[0], bytes[1]])
1401 }
1402
1403 fn to_ne_bytes(&self) -> &[u8] {
1404 let i: *const Self = self;
1405 unsafe { std::slice::from_raw_parts(i.cast::<u8>(), std::mem::size_of::<Self>()) }
1406 }
1407
1408 fn dtype() -> DType {
1409 DType::I16
1410 }
1411
1412 fn zero() -> Self {
1413 0
1414 }
1415
1416 fn one() -> Self {
1417 1
1418 }
1419
1420 fn abs(self) -> Self {
1421 self.abs()
1422 }
1423
1424 fn neg(self) -> Self {
1425 -self
1426 }
1427
1428 fn exp2(self) -> Self {
1429 2i32.pow(self as u32) as i16
1430 }
1431
1432 fn log2(self) -> Self {
1433 f64::from(self).log2() as i16
1434 }
1435
1436 fn relu(self) -> Self {
1437 Scalar::max(self, 0)
1438 }
1439
1440 fn not(self) -> Self {
1441 i16::from(self == 0)
1442 }
1443
1444 fn nonzero(self) -> Self {
1445 i16::from(self != 0)
1446 }
1447
1448 fn add(self, rhs: Self) -> Self {
1449 self + rhs
1450 }
1451
1452 fn sub(self, rhs: Self) -> Self {
1453 self - rhs
1454 }
1455
1456 fn mul(self, rhs: Self) -> Self {
1457 self * rhs
1458 }
1459
1460 fn div(self, rhs: Self) -> Self {
1461 self / rhs
1462 }
1463
1464 fn pow(self, rhs: Self) -> Self {
1465 if rhs >= 0 {
1466 return self.pow(rhs as u32);
1467 }
1468 if self == 1 {
1469 return 1;
1470 }
1471 if self == -1 {
1472 return if rhs % 2 == 0 { 1 } else { -1 };
1473 }
1474 0
1475 }
1476
1477 fn mod_(self, rhs: Self) -> Self {
1478 self % rhs
1479 }
1480
1481 fn cmplt(self, rhs: Self) -> bool {
1482 self < rhs
1483 }
1484
1485 fn cmpgt(self, rhs: Self) -> bool {
1486 self > rhs
1487 }
1488
1489 fn noteq(self, rhs: Self) -> bool {
1490 self != rhs
1491 }
1492
1493 fn or(self, rhs: Self) -> bool {
1494 self != 0 || rhs != 0
1495 }
1496
1497 fn bitxor(self, rhs: Self) -> Self {
1498 self ^ rhs
1499 }
1500
1501 fn bitor(self, rhs: Self) -> Self {
1502 self | rhs
1503 }
1504
1505 fn bitand(self, rhs: Self) -> Self {
1506 self & rhs
1507 }
1508
1509 fn bitshiftleft(self, rhs: Self) -> Self {
1510 self.wrapping_shl(rhs as u32)
1511 }
1512
1513 fn bitshiftright(self, rhs: Self) -> Self {
1514 self.wrapping_shr(rhs as u32)
1515 }
1516
1517 fn and(self, rhs: Self) -> bool {
1518 self != 0 && rhs != 0
1519 }
1520
1521 fn max(self, rhs: Self) -> Self {
1522 Ord::max(self, rhs)
1523 }
1524
1525 fn max_value() -> Self {
1526 i16::MAX
1527 }
1528
1529 fn min_value() -> Self {
1530 i16::MIN
1531 }
1532
1533 fn is_equal(self, rhs: Self) -> bool {
1534 self == rhs
1535 }
1536
1537 fn epsilon() -> Self {
1538 0
1539 }
1540}
1541
1542impl Scalar for i32 {
1543 fn from_bf16(t: bf16) -> Self {
1544 t.to_f32() as i32
1545 }
1546
1547 fn from_f16(t: f16) -> Self {
1548 t.to_f32() as i32
1549 }
1550
1551 #[allow(clippy::cast_possible_truncation)]
1552 fn from_f32(t: f32) -> Self {
1553 t as i32
1554 }
1555
1556 #[allow(clippy::cast_possible_truncation)]
1557 fn from_f64(t: f64) -> Self {
1558 t as i32
1559 }
1560
1561 fn from_u8(t: u8) -> Self {
1562 t.into()
1563 }
1564
1565 fn from_u16(t: u16) -> Self {
1566 t.into()
1567 }
1568
1569 fn from_u32(t: u32) -> Self {
1570 i32::try_from(t).unwrap()
1571 }
1572
1573 fn from_u64(t: u64) -> Self {
1574 t.try_into().unwrap()
1575 }
1576
1577 fn from_i8(t: i8) -> Self {
1578 t.into()
1579 }
1580
1581 fn from_i16(t: i16) -> Self {
1582 t.into()
1583 }
1584
1585 fn from_i32(t: i32) -> Self {
1586 t
1587 }
1588
1589 #[allow(clippy::cast_possible_truncation)]
1590 fn from_i64(t: i64) -> Self {
1591 t as i32
1592 }
1593
1594 fn from_bool(t: bool) -> Self {
1595 t.into()
1596 }
1597
1598 fn from_le_bytes(bytes: &[u8]) -> Self {
1599 i32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]])
1600 }
1601
1602 fn to_ne_bytes(&self) -> &[u8] {
1603 let i: *const i32 = self;
1604 unsafe { std::slice::from_raw_parts(i.cast::<u8>(), std::mem::size_of::<i32>()) }
1605 }
1606
1607 fn dtype() -> DType {
1608 DType::I32
1609 }
1610
1611 fn zero() -> Self {
1612 0
1613 }
1614
1615 fn one() -> Self {
1616 1
1617 }
1618
1619 fn abs(self) -> Self {
1620 self.abs()
1621 }
1622
1623 fn neg(self) -> Self {
1624 -self
1625 }
1626
1627 fn exp2(self) -> Self {
1628 2i32.pow(self as u32)
1629 }
1630
1631 fn log2(self) -> Self {
1632 f64::from(self).log2() as i32
1633 }
1634
1635 fn relu(self) -> Self {
1636 Scalar::max(self, 0)
1637 }
1638
1639 fn not(self) -> Self {
1640 i32::from(self == 0)
1641 }
1642
1643 fn nonzero(self) -> Self {
1644 i32::from(self != 0)
1645 }
1646
1647 fn add(self, rhs: Self) -> Self {
1648 self + rhs
1649 }
1650
1651 fn sub(self, rhs: Self) -> Self {
1652 self - rhs
1653 }
1654
1655 fn mul(self, rhs: Self) -> Self {
1656 self * rhs
1657 }
1658
1659 fn div(self, rhs: Self) -> Self {
1660 self / rhs
1661 }
1662
1663 fn pow(self, rhs: Self) -> Self {
1664 i32::pow(self, u32::try_from(rhs).unwrap())
1665 }
1666
1667 fn mod_(self, rhs: Self) -> Self {
1668 self % rhs
1669 }
1670
1671 fn cmplt(self, rhs: Self) -> bool {
1672 self < rhs
1673 }
1674
1675 fn cmpgt(self, rhs: Self) -> bool {
1676 self > rhs
1677 }
1678
1679 fn noteq(self, rhs: Self) -> bool {
1680 self != rhs
1681 }
1682
1683 fn or(self, rhs: Self) -> bool {
1684 self != 0 || rhs != 0
1685 }
1686
1687 fn bitxor(self, rhs: Self) -> Self {
1688 self ^ rhs
1689 }
1690
1691 fn bitor(self, rhs: Self) -> Self {
1692 self | rhs
1693 }
1694
1695 fn bitand(self, rhs: Self) -> Self {
1696 self & rhs
1697 }
1698
1699 fn bitshiftleft(self, rhs: Self) -> Self {
1700 self.wrapping_shl(rhs as u32)
1701 }
1702
1703 fn bitshiftright(self, rhs: Self) -> Self {
1704 self.wrapping_shr(rhs as u32)
1705 }
1706
1707 fn and(self, rhs: Self) -> bool {
1708 self != 0 && rhs != 0
1709 }
1710
1711 fn max(self, rhs: Self) -> Self {
1712 <i32 as Ord>::max(self, rhs)
1713 }
1714
1715 fn max_value() -> Self {
1716 i32::MAX
1717 }
1718
1719 fn min_value() -> Self {
1720 i32::MIN
1721 }
1722
1723 fn is_equal(self, rhs: Self) -> bool {
1724 self == rhs
1725 }
1726
1727 fn epsilon() -> Self {
1728 0
1729 }
1730}
1731
1732impl Scalar for i64 {
1733 fn from_bf16(t: bf16) -> Self {
1734 t.to_f32() as i64
1735 }
1736
1737 fn from_f16(t: f16) -> Self {
1738 t.to_f32() as i64
1739 }
1740
1741 #[allow(clippy::cast_possible_truncation)]
1742 fn from_f32(t: f32) -> Self {
1743 t as Self
1744 }
1745
1746 #[allow(clippy::cast_possible_truncation)]
1747 fn from_f64(t: f64) -> Self {
1748 t as Self
1749 }
1750
1751 fn from_u8(t: u8) -> Self {
1752 t.into()
1753 }
1754
1755 fn from_u16(t: u16) -> Self {
1756 t.into()
1757 }
1758
1759 fn from_u32(t: u32) -> Self {
1760 t.into()
1761 }
1762
1763 fn from_u64(t: u64) -> Self {
1764 t.try_into().unwrap()
1765 }
1766
1767 fn from_i8(t: i8) -> Self {
1768 t.into()
1769 }
1770
1771 fn from_i16(t: i16) -> Self {
1772 t.into()
1773 }
1774
1775 fn from_i32(t: i32) -> Self {
1776 t.into()
1777 }
1778
1779 fn from_i64(t: i64) -> Self {
1780 t
1781 }
1782
1783 fn from_bool(t: bool) -> Self {
1784 t.into()
1785 }
1786
1787 fn from_le_bytes(bytes: &[u8]) -> Self {
1788 i64::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7]])
1789 }
1790
1791 fn to_ne_bytes(&self) -> &[u8] {
1792 let i: *const Self = self;
1793 unsafe { std::slice::from_raw_parts(i.cast::<u8>(), std::mem::size_of::<Self>()) }
1794 }
1795
1796 fn dtype() -> DType {
1797 DType::I64
1798 }
1799
1800 fn zero() -> Self {
1801 0
1802 }
1803
1804 fn one() -> Self {
1805 1
1806 }
1807
1808 fn abs(self) -> Self {
1809 self.abs()
1810 }
1811
1812 fn neg(self) -> Self {
1813 -self
1814 }
1815
1816 fn exp2(self) -> Self {
1817 2i64.pow(self as u32)
1818 }
1819
1820 fn log2(self) -> Self {
1821 self as f64 as i64
1822 }
1823
1824 fn relu(self) -> Self {
1825 Scalar::max(self, 0)
1826 }
1827
1828 fn not(self) -> Self {
1829 i64::from(self == 0)
1830 }
1831
1832 fn nonzero(self) -> Self {
1833 i64::from(self != 0)
1834 }
1835
1836 fn add(self, rhs: Self) -> Self {
1837 self + rhs
1838 }
1839
1840 fn sub(self, rhs: Self) -> Self {
1841 self - rhs
1842 }
1843
1844 fn mul(self, rhs: Self) -> Self {
1845 self * rhs
1846 }
1847
1848 fn div(self, rhs: Self) -> Self {
1849 self / rhs
1850 }
1851
1852 fn pow(self, rhs: Self) -> Self {
1853 i64::pow(self, u32::try_from(rhs).unwrap())
1854 }
1855
1856 fn mod_(self, rhs: Self) -> Self {
1857 self % rhs
1858 }
1859
1860 fn cmplt(self, rhs: Self) -> bool {
1861 self < rhs
1862 }
1863
1864 fn cmpgt(self, rhs: Self) -> bool {
1865 self > rhs
1866 }
1867
1868 fn noteq(self, rhs: Self) -> bool {
1869 self != rhs
1870 }
1871
1872 fn or(self, rhs: Self) -> bool {
1873 self != 0 || rhs != 0
1874 }
1875
1876 fn bitxor(self, rhs: Self) -> Self {
1877 self ^ rhs
1878 }
1879
1880 fn bitor(self, rhs: Self) -> Self {
1881 self | rhs
1882 }
1883
1884 fn bitand(self, rhs: Self) -> Self {
1885 self & rhs
1886 }
1887
1888 fn bitshiftleft(self, rhs: Self) -> Self {
1889 self.wrapping_shl(rhs as u32)
1890 }
1891
1892 fn bitshiftright(self, rhs: Self) -> Self {
1893 self.wrapping_shr(rhs as u32)
1894 }
1895
1896 fn and(self, rhs: Self) -> bool {
1897 self != 0 && rhs != 0
1898 }
1899
1900 fn max(self, rhs: Self) -> Self {
1901 <i64 as Ord>::max(self, rhs)
1902 }
1903
1904 fn max_value() -> Self {
1905 Self::MAX
1906 }
1907
1908 fn min_value() -> Self {
1909 Self::MIN
1910 }
1911
1912 fn is_equal(self, rhs: Self) -> bool {
1913 self == rhs
1914 }
1915
1916 fn epsilon() -> Self {
1917 0
1918 }
1919}
1920
1921impl Scalar for u8 {
1922 fn from_bf16(t: bf16) -> Self {
1923 t.to_f32() as u32 as u8
1924 }
1925
1926 fn from_f16(t: f16) -> Self {
1927 t.to_f32() as u32 as u8
1928 }
1929
1930 fn from_f32(t: f32) -> Self {
1931 t as u32 as u8
1932 }
1933
1934 fn from_f64(t: f64) -> Self {
1935 t as u32 as u8
1936 }
1937
1938 fn from_u8(t: u8) -> Self {
1939 t
1940 }
1941
1942 fn from_u16(t: u16) -> Self {
1943 t.try_into().unwrap()
1944 }
1945
1946 fn from_u32(t: u32) -> Self {
1947 t.try_into().unwrap()
1948 }
1949
1950 fn from_u64(t: u64) -> Self {
1951 t.try_into().unwrap()
1952 }
1953
1954 fn from_i8(t: i8) -> Self {
1955 t.try_into().unwrap()
1956 }
1957
1958 fn from_i16(t: i16) -> Self {
1959 t.try_into().unwrap()
1960 }
1961
1962 fn from_i32(t: i32) -> Self {
1963 t.try_into().unwrap()
1964 }
1965
1966 fn from_i64(t: i64) -> Self {
1967 t.try_into().unwrap()
1968 }
1969
1970 fn from_bool(t: bool) -> Self {
1971 t.into()
1972 }
1973
1974 fn from_le_bytes(bytes: &[u8]) -> Self {
1975 u8::from_le_bytes([bytes[0]])
1976 }
1977
1978 fn to_ne_bytes(&self) -> &[u8] {
1979 let i: *const Self = self;
1980 unsafe { std::slice::from_raw_parts(i.cast::<u8>(), std::mem::size_of::<Self>()) }
1981 }
1982
1983 fn dtype() -> DType {
1984 DType::U8
1985 }
1986
1987 fn zero() -> Self {
1988 0
1989 }
1990
1991 fn one() -> Self {
1992 1
1993 }
1994
1995 fn abs(self) -> Self {
1996 self
1997 }
1998
1999 fn neg(self) -> Self {
2000 self.wrapping_neg()
2001 }
2002
2003 fn exp2(self) -> Self {
2004 if self <= 31 { 2u32.pow(self as u32) as u8 } else { 255 }
2005 }
2006
2007 fn log2(self) -> Self {
2008 self.ilog2() as u8
2009 }
2010
2011 fn relu(self) -> Self {
2012 self
2013 }
2014
2015 fn not(self) -> Self {
2016 u8::from(self == 0)
2017 }
2018
2019 fn nonzero(self) -> Self {
2020 u8::from(self != 0)
2021 }
2022
2023 fn add(self, rhs: Self) -> Self {
2024 self + rhs
2025 }
2026
2027 fn sub(self, rhs: Self) -> Self {
2028 self - rhs
2029 }
2030
2031 fn mul(self, rhs: Self) -> Self {
2032 self * rhs
2033 }
2034
2035 fn div(self, rhs: Self) -> Self {
2036 self / rhs
2037 }
2038
2039 fn pow(self, rhs: Self) -> Self {
2040 Self::pow(self, u32::from(rhs))
2041 }
2042
2043 fn mod_(self, rhs: Self) -> Self {
2044 self % rhs
2045 }
2046
2047 fn cmplt(self, rhs: Self) -> bool {
2048 self < rhs
2049 }
2050
2051 fn cmpgt(self, rhs: Self) -> bool {
2052 self > rhs
2053 }
2054
2055 fn noteq(self, rhs: Self) -> bool {
2056 self != rhs
2057 }
2058
2059 fn or(self, rhs: Self) -> bool {
2060 self != 0 || rhs != 0
2061 }
2062
2063 fn bitxor(self, rhs: Self) -> Self {
2064 self ^ rhs
2065 }
2066
2067 fn bitor(self, rhs: Self) -> Self {
2068 self | rhs
2069 }
2070
2071 fn bitand(self, rhs: Self) -> Self {
2072 self & rhs
2073 }
2074
2075 fn bitshiftleft(self, rhs: Self) -> Self {
2076 self.wrapping_shl(rhs as u32)
2077 }
2078
2079 fn bitshiftright(self, rhs: Self) -> Self {
2080 self.wrapping_shr(rhs as u32)
2081 }
2082
2083 fn and(self, rhs: Self) -> bool {
2084 self != 0 && rhs != 0
2085 }
2086
2087 fn max(self, rhs: Self) -> Self {
2088 Ord::max(self, rhs)
2089 }
2090
2091 fn max_value() -> Self {
2092 u8::MAX
2093 }
2094
2095 fn min_value() -> Self {
2096 u8::MIN
2097 }
2098
2099 fn is_equal(self, rhs: Self) -> bool {
2100 self == rhs
2101 }
2102
2103 fn epsilon() -> Self {
2104 0
2105 }
2106}
2107
2108impl Scalar for u16 {
2109 fn from_bf16(t: bf16) -> Self {
2110 t.to_f32() as u32 as u16
2111 }
2112
2113 fn from_f16(t: f16) -> Self {
2114 t.to_f32() as u32 as u16
2115 }
2116
2117 fn from_f32(t: f32) -> Self {
2118 t as u32 as u16
2119 }
2120
2121 fn from_f64(t: f64) -> Self {
2122 t as u32 as u16
2123 }
2124
2125 fn from_u8(t: u8) -> Self {
2126 t.into()
2127 }
2128
2129 fn from_u16(t: u16) -> Self {
2130 t
2131 }
2132
2133 fn from_u32(t: u32) -> Self {
2134 t.try_into().unwrap()
2135 }
2136
2137 fn from_u64(t: u64) -> Self {
2138 t.try_into().unwrap()
2139 }
2140
2141 fn from_i8(t: i8) -> Self {
2142 t.try_into().unwrap()
2143 }
2144
2145 fn from_i16(t: i16) -> Self {
2146 t.try_into().unwrap()
2147 }
2148
2149 fn from_i32(t: i32) -> Self {
2150 t.try_into().unwrap()
2151 }
2152
2153 fn from_i64(t: i64) -> Self {
2154 t.try_into().unwrap()
2155 }
2156
2157 fn from_bool(t: bool) -> Self {
2158 t.into()
2159 }
2160
2161 fn from_le_bytes(bytes: &[u8]) -> Self {
2162 Self::from_le_bytes([bytes[0], bytes[1]])
2163 }
2164
2165 fn to_ne_bytes(&self) -> &[u8] {
2166 let i: *const Self = self;
2167 unsafe { std::slice::from_raw_parts(i.cast::<u8>(), std::mem::size_of::<Self>()) }
2168 }
2169
2170 fn dtype() -> DType {
2171 DType::U16
2172 }
2173
2174 fn zero() -> Self {
2175 0
2176 }
2177
2178 fn one() -> Self {
2179 1
2180 }
2181
2182 fn abs(self) -> Self {
2183 self
2184 }
2185
2186 fn neg(self) -> Self {
2187 self.wrapping_neg()
2188 }
2189
2190 fn exp2(self) -> Self {
2191 if self <= 31 { 2u32.pow(self as u32) as u16 } else { 65535 }
2192 }
2193
2194 fn log2(self) -> Self {
2195 self.ilog2() as u16
2196 }
2197
2198 fn relu(self) -> Self {
2199 self
2200 }
2201
2202 fn not(self) -> Self {
2203 u16::from(self == 0)
2204 }
2205
2206 fn nonzero(self) -> Self {
2207 u16::from(self != 0)
2208 }
2209
2210 fn add(self, rhs: Self) -> Self {
2211 self + rhs
2212 }
2213
2214 fn sub(self, rhs: Self) -> Self {
2215 self - rhs
2216 }
2217
2218 fn mul(self, rhs: Self) -> Self {
2219 self * rhs
2220 }
2221
2222 fn div(self, rhs: Self) -> Self {
2223 self / rhs
2224 }
2225
2226 fn pow(self, rhs: Self) -> Self {
2227 Self::pow(self, u32::from(rhs)) as u16
2228 }
2229
2230 fn mod_(self, rhs: Self) -> Self {
2231 self % rhs
2232 }
2233
2234 fn cmplt(self, rhs: Self) -> bool {
2235 self < rhs
2236 }
2237
2238 fn cmpgt(self, rhs: Self) -> bool {
2239 self > rhs
2240 }
2241
2242 fn noteq(self, rhs: Self) -> bool {
2243 self != rhs
2244 }
2245
2246 fn or(self, rhs: Self) -> bool {
2247 self != 0 || rhs != 0
2248 }
2249
2250 fn bitxor(self, rhs: Self) -> Self {
2251 self ^ rhs
2252 }
2253
2254 fn bitor(self, rhs: Self) -> Self {
2255 self | rhs
2256 }
2257
2258 fn bitand(self, rhs: Self) -> Self {
2259 self & rhs
2260 }
2261
2262 fn bitshiftleft(self, rhs: Self) -> Self {
2263 self.wrapping_shl(rhs as u32)
2264 }
2265
2266 fn bitshiftright(self, rhs: Self) -> Self {
2267 self.wrapping_shr(rhs as u32)
2268 }
2269
2270 fn and(self, rhs: Self) -> bool {
2271 self != 0 && rhs != 0
2272 }
2273
2274 fn max(self, rhs: Self) -> Self {
2275 Ord::max(self, rhs)
2276 }
2277
2278 fn max_value() -> Self {
2279 Self::MAX
2280 }
2281
2282 fn min_value() -> Self {
2283 Self::MIN
2284 }
2285
2286 fn is_equal(self, rhs: Self) -> bool {
2287 self == rhs
2288 }
2289
2290 fn epsilon() -> Self {
2291 0
2292 }
2293
2294 fn cast<T: Scalar>(self) -> T {
2295 use core::mem::transmute_copy as t;
2296 unsafe {
2297 match Self::dtype() {
2298 DType::BF16 => T::from_bf16(t(&self)),
2299 DType::F16 => T::from_f16(t(&self)),
2300 DType::F32 => T::from_f32(t(&self)),
2301 DType::F64 => T::from_f64(t(&self)),
2302 DType::U8 => T::from_u8(t(&self)),
2303 DType::U16 => T::from_u16(t(&self)),
2304 DType::U32 => T::from_u32(t(&self)),
2305 DType::U64 => T::from_u64(t(&self)),
2306 DType::I8 => T::from_i8(t(&self)),
2307 DType::I16 => T::from_i16(t(&self)),
2308 DType::I32 => T::from_i32(t(&self)),
2309 DType::I64 => T::from_i64(t(&self)),
2310 DType::Bool => T::from_bool(t(&self)),
2311 }
2312 }
2313 }
2314}
2315
2316impl Scalar for u32 {
2317 fn from_bf16(t: bf16) -> Self {
2318 t.to_f32() as u32
2319 }
2320
2321 fn from_f16(t: f16) -> Self {
2322 t.to_f32() as Self
2323 }
2324
2325 fn from_f32(t: f32) -> Self {
2326 t as Self
2327 }
2328
2329 fn from_f64(t: f64) -> Self {
2330 t as Self
2331 }
2332
2333 fn from_u8(t: u8) -> Self {
2334 t.into()
2335 }
2336
2337 fn from_u16(t: u16) -> Self {
2338 t.into()
2339 }
2340
2341 fn from_u32(t: u32) -> Self {
2342 t
2343 }
2344
2345 fn from_u64(t: u64) -> Self {
2346 t.try_into().unwrap()
2347 }
2348
2349 fn from_i8(t: i8) -> Self {
2350 t.try_into().unwrap()
2351 }
2352
2353 fn from_i16(t: i16) -> Self {
2354 t.try_into().unwrap()
2355 }
2356
2357 fn from_i32(t: i32) -> Self {
2358 t.try_into().unwrap()
2359 }
2360
2361 fn from_i64(t: i64) -> Self {
2362 t.try_into().unwrap()
2363 }
2364
2365 fn from_bool(t: bool) -> Self {
2366 t.into()
2367 }
2368
2369 fn from_le_bytes(bytes: &[u8]) -> Self {
2370 Self::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]])
2371 }
2372
2373 fn to_ne_bytes(&self) -> &[u8] {
2374 let i: *const Self = self;
2375 unsafe { std::slice::from_raw_parts(i.cast::<u8>(), std::mem::size_of::<Self>()) }
2376 }
2377
2378 fn dtype() -> DType {
2379 DType::U32
2380 }
2381
2382 fn zero() -> Self {
2383 0
2384 }
2385
2386 fn one() -> Self {
2387 1
2388 }
2389
2390 fn abs(self) -> Self {
2391 self
2392 }
2393
2394 fn neg(self) -> Self {
2395 self.wrapping_neg()
2396 }
2397
2398 fn exp2(self) -> Self {
2399 if self <= 31 { 2u32.pow(self) } else { u32::MAX }
2400 }
2401
2402 fn log2(self) -> Self {
2403 self.ilog2()
2404 }
2405
2406 fn relu(self) -> Self {
2407 self
2408 }
2409
2410 fn not(self) -> Self {
2411 u32::from(self == 0)
2412 }
2413
2414 fn nonzero(self) -> Self {
2415 u32::from(self != 0)
2416 }
2417
2418 fn add(self, rhs: Self) -> Self {
2419 self.wrapping_add(rhs)
2420 }
2421
2422 fn sub(self, rhs: Self) -> Self {
2423 self.wrapping_sub(rhs)
2424 }
2425
2426 fn mul(self, rhs: Self) -> Self {
2427 self * rhs
2428 }
2429
2430 fn div(self, rhs: Self) -> Self {
2431 self / rhs
2432 }
2433
2434 fn pow(self, rhs: Self) -> Self {
2435 u32::pow(self, rhs)
2436 }
2437
2438 fn mod_(self, rhs: Self) -> Self {
2439 self % rhs
2440 }
2441
2442 fn cmplt(self, rhs: Self) -> bool {
2443 self < rhs
2444 }
2445
2446 fn cmpgt(self, rhs: Self) -> bool {
2447 self > rhs
2448 }
2449
2450 fn noteq(self, rhs: Self) -> bool {
2451 self != rhs
2452 }
2453
2454 fn or(self, rhs: Self) -> bool {
2455 self != 0 || rhs != 0
2456 }
2457
2458 fn bitxor(self, rhs: Self) -> Self {
2459 self ^ rhs
2460 }
2461
2462 fn bitor(self, rhs: Self) -> Self {
2463 self | rhs
2464 }
2465
2466 fn bitand(self, rhs: Self) -> Self {
2467 self & rhs
2468 }
2469
2470 fn bitshiftleft(self, rhs: Self) -> Self {
2471 self << rhs
2472 }
2473
2474 fn bitshiftright(self, rhs: Self) -> Self {
2475 self >> rhs
2476 }
2477
2478 fn and(self, rhs: Self) -> bool {
2479 self != 0 && rhs != 0
2480 }
2481
2482 fn max(self, rhs: Self) -> Self {
2483 Ord::max(self, rhs)
2484 }
2485
2486 fn max_value() -> Self {
2487 Self::MAX
2488 }
2489
2490 fn min_value() -> Self {
2491 Self::MIN
2492 }
2493
2494 fn is_equal(self, rhs: Self) -> bool {
2495 self == rhs
2496 }
2497
2498 fn epsilon() -> Self {
2499 0
2500 }
2501
2502 fn cast<T: Scalar>(self) -> T {
2503 use core::mem::transmute_copy as t;
2504 unsafe {
2505 match Self::dtype() {
2506 DType::BF16 => T::from_bf16(t(&self)),
2507 DType::F16 => T::from_f16(t(&self)),
2508 DType::F32 => T::from_f32(t(&self)),
2509 DType::F64 => T::from_f64(t(&self)),
2510 DType::U8 => T::from_u8(t(&self)),
2511 DType::U16 => T::from_u16(t(&self)),
2512 DType::U32 => T::from_u32(t(&self)),
2513 DType::U64 => T::from_u64(t(&self)),
2514 DType::I8 => T::from_i8(t(&self)),
2515 DType::I16 => T::from_i16(t(&self)),
2516 DType::I32 => T::from_i32(t(&self)),
2517 DType::I64 => T::from_i64(t(&self)),
2518 DType::Bool => T::from_bool(t(&self)),
2519 }
2520 }
2521 }
2522}
2523
2524impl Scalar for u64 {
2525 fn from_bf16(t: bf16) -> Self {
2526 t.to_f32() as u64
2527 }
2528
2529 fn from_f16(t: f16) -> Self {
2530 t.to_f32() as Self
2531 }
2532
2533 fn from_f32(t: f32) -> Self {
2534 t as u64
2535 }
2536
2537 fn from_f64(t: f64) -> Self {
2538 t as Self
2539 }
2540
2541 fn from_u8(t: u8) -> Self {
2542 t.into()
2543 }
2544
2545 fn from_u16(t: u16) -> Self {
2546 t.into()
2547 }
2548
2549 fn from_u32(t: u32) -> Self {
2550 t.into()
2551 }
2552
2553 fn from_u64(t: u64) -> Self {
2554 t
2555 }
2556
2557 fn from_i8(t: i8) -> Self {
2558 t.try_into().unwrap()
2559 }
2560
2561 fn from_i16(t: i16) -> Self {
2562 t.try_into().unwrap()
2563 }
2564
2565 fn from_i32(t: i32) -> Self {
2566 t.try_into().unwrap()
2567 }
2568
2569 fn from_i64(t: i64) -> Self {
2570 t.try_into().unwrap()
2571 }
2572
2573 fn from_bool(t: bool) -> Self {
2574 t.into()
2575 }
2576
2577 fn from_le_bytes(bytes: &[u8]) -> Self {
2578 Self::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7]])
2579 }
2580
2581 fn to_ne_bytes(&self) -> &[u8] {
2582 let i: *const Self = self;
2583 unsafe { std::slice::from_raw_parts(i.cast::<u8>(), std::mem::size_of::<Self>()) }
2584 }
2585
2586 fn dtype() -> DType {
2587 DType::U64
2588 }
2589
2590 fn zero() -> Self {
2591 0
2592 }
2593
2594 fn one() -> Self {
2595 1
2596 }
2597
2598 fn abs(self) -> Self {
2599 self
2600 }
2601
2602 fn neg(self) -> Self {
2603 self.wrapping_neg()
2604 }
2605
2606 fn exp2(self) -> Self {
2607 if self <= 63 { 2u64.pow(self as u32) } else { u64::MAX }
2608 }
2609
2610 fn log2(self) -> Self {
2611 self.ilog2() as u64
2612 }
2613
2614 fn relu(self) -> Self {
2615 self
2616 }
2617
2618 fn not(self) -> Self {
2619 u64::from(self == 0)
2620 }
2621
2622 fn nonzero(self) -> Self {
2623 u64::from(self != 0)
2624 }
2625
2626 fn add(self, rhs: Self) -> Self {
2627 self + rhs
2628 }
2629
2630 fn sub(self, rhs: Self) -> Self {
2631 self.wrapping_sub(rhs)
2632 }
2633
2634 fn mul(self, rhs: Self) -> Self {
2635 self * rhs
2636 }
2637
2638 fn div(self, rhs: Self) -> Self {
2639 self / rhs
2640 }
2641
2642 fn pow(self, rhs: Self) -> Self {
2643 i64::pow(self as i64, u32::try_from(rhs).unwrap()) as u64
2644 }
2645
2646 fn mod_(self, rhs: Self) -> Self {
2647 self % rhs
2648 }
2649
2650 fn cmplt(self, rhs: Self) -> bool {
2651 self < rhs
2652 }
2653
2654 fn cmpgt(self, rhs: Self) -> bool {
2655 self > rhs
2656 }
2657
2658 fn noteq(self, rhs: Self) -> bool {
2659 self != rhs
2660 }
2661
2662 fn or(self, rhs: Self) -> bool {
2663 self != 0 || rhs != 0
2664 }
2665
2666 fn bitxor(self, rhs: Self) -> Self {
2667 self ^ rhs
2668 }
2669
2670 fn bitor(self, rhs: Self) -> Self {
2671 self | rhs
2672 }
2673
2674 fn bitand(self, rhs: Self) -> Self {
2675 self & rhs
2676 }
2677
2678 fn bitshiftleft(self, rhs: Self) -> Self {
2679 self << rhs
2680 }
2681
2682 fn bitshiftright(self, rhs: Self) -> Self {
2683 self >> rhs
2684 }
2685
2686 fn and(self, rhs: Self) -> bool {
2687 self != 0 && rhs != 0
2688 }
2689
2690 fn max(self, rhs: Self) -> Self {
2691 Ord::max(self, rhs)
2692 }
2693
2694 fn max_value() -> Self {
2695 Self::MAX
2696 }
2697
2698 fn min_value() -> Self {
2699 Self::MIN
2700 }
2701
2702 fn is_equal(self, rhs: Self) -> bool {
2703 self == rhs
2704 }
2705
2706 fn epsilon() -> Self {
2707 0
2708 }
2709}
2710
2711impl Scalar for bool {
2712 fn from_bf16(t: bf16) -> Self {
2713 t != bf16::ZERO
2714 }
2715
2716 fn from_f16(t: f16) -> Self {
2717 t != f16::ZERO
2718 }
2719
2720 fn from_f32(t: f32) -> Self {
2721 t != 0.
2722 }
2723
2724 fn from_f64(t: f64) -> Self {
2725 t != 0.
2726 }
2727
2728 fn from_u8(t: u8) -> Self {
2729 t != 0
2730 }
2731
2732 fn from_u16(t: u16) -> Self {
2733 t != 0
2734 }
2735
2736 fn from_u32(t: u32) -> Self {
2737 t != 0
2738 }
2739
2740 fn from_u64(t: u64) -> Self {
2741 t != 0
2742 }
2743
2744 fn from_i8(t: i8) -> Self {
2745 t != 0
2746 }
2747
2748 fn from_i16(t: i16) -> Self {
2749 t != 0
2750 }
2751
2752 fn from_i32(t: i32) -> Self {
2753 t != 0
2754 }
2755
2756 fn from_i64(t: i64) -> Self {
2757 t != 0
2758 }
2759
2760 fn from_bool(t: bool) -> Self {
2761 t
2762 }
2763
2764 fn from_le_bytes(bytes: &[u8]) -> Self {
2765 bytes[0] != 0
2766 }
2767
2768 fn to_ne_bytes(&self) -> &[u8] {
2769 let i: *const Self = self;
2770 unsafe { std::slice::from_raw_parts(i.cast::<u8>(), std::mem::size_of::<Self>()) }
2771 }
2772
2773 fn dtype() -> DType {
2774 DType::Bool
2775 }
2776
2777 fn zero() -> Self {
2778 false
2779 }
2780
2781 fn one() -> Self {
2782 true
2783 }
2784
2785 fn abs(self) -> Self {
2786 self
2787 }
2788
2789 fn neg(self) -> Self {
2790 panic!()
2791 }
2792
2793 fn exp2(self) -> Self {
2794 panic!()
2795 }
2796
2797 fn log2(self) -> Self {
2798 panic!()
2799 }
2800
2801 fn relu(self) -> Self {
2802 panic!()
2803 }
2804
2805 fn not(self) -> Self {
2806 !self
2807 }
2808
2809 fn nonzero(self) -> Self {
2810 self
2811 }
2812
2813 fn add(self, rhs: Self) -> Self {
2814 self | rhs
2815 }
2816
2817 fn sub(self, rhs: Self) -> Self {
2818 let _ = rhs;
2819 panic!()
2820 }
2821
2822 fn mul(self, rhs: Self) -> Self {
2823 self & rhs
2824 }
2825
2826 fn div(self, rhs: Self) -> Self {
2827 let _ = rhs;
2828 panic!()
2829 }
2830
2831 fn pow(self, rhs: Self) -> Self {
2832 let _ = rhs;
2833 panic!()
2834 }
2835
2836 fn mod_(self, rhs: Self) -> Self {
2837 let _ = rhs;
2838 panic!()
2839 }
2840
2841 fn cmplt(self, rhs: Self) -> Self {
2842 !self & rhs
2843 }
2844
2845 fn cmpgt(self, rhs: Self) -> Self {
2846 self && !rhs
2847 }
2848
2849 fn noteq(self, rhs: Self) -> bool {
2850 self != rhs
2851 }
2852
2853 fn or(self, rhs: Self) -> Self {
2854 self || rhs
2855 }
2856
2857 fn bitxor(self, rhs: Self) -> Self {
2858 self ^ rhs
2859 }
2860
2861 fn bitor(self, rhs: Self) -> Self {
2862 self | rhs
2863 }
2864
2865 fn bitand(self, rhs: Self) -> Self {
2866 self & rhs
2867 }
2868
2869 fn bitshiftleft(self, _rhs: Self) -> Self {
2870 self
2871 }
2872
2873 fn bitshiftright(self, _rhs: Self) -> Self {
2874 false
2875 }
2876
2877 fn and(self, rhs: Self) -> bool {
2878 self && rhs
2879 }
2880
2881 fn max(self, rhs: Self) -> Self {
2882 <bool as Ord>::max(self, rhs)
2883 }
2884
2885 fn max_value() -> Self {
2886 true
2887 }
2888
2889 fn min_value() -> Self {
2890 false
2891 }
2892
2893 fn is_equal(self, rhs: Self) -> bool {
2894 self == rhs
2895 }
2896
2897 fn epsilon() -> Self {
2898 false
2899 }
2900
2901 fn cast<T: Scalar>(self) -> T {
2902 use core::mem::transmute_copy as t;
2903 unsafe {
2904 match Self::dtype() {
2905 DType::BF16 => T::from_bf16(t(&self)),
2906 DType::F16 => T::from_f16(t(&self)),
2907 DType::F32 => T::from_f32(t(&self)),
2908 DType::F64 => T::from_f64(t(&self)),
2909 DType::U8 => T::from_u8(t(&self)),
2910 DType::U16 => T::from_u16(t(&self)),
2911 DType::U32 => T::from_u32(t(&self)),
2912 DType::U64 => T::from_u64(t(&self)),
2913 DType::I8 => T::from_i8(t(&self)),
2914 DType::I16 => T::from_i16(t(&self)),
2915 DType::I32 => T::from_i32(t(&self)),
2916 DType::I64 => T::from_i64(t(&self)),
2917 DType::Bool => T::from_bool(t(&self)),
2918 }
2919 }
2920 }
2921}