/// Generates one `dot_<type>` method per listed element type.
///
/// Each generated method returns the dot product of two equally long slices,
/// with every element cast to `f64` before multiplying and summing.
#[macro_export]
macro_rules! impl_dot {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                /// Dot product of `a` and `b`, accumulated in `f64`.
                ///
                /// # Panics
                /// Panics when `a` and `b` differ in length.
                #[inline(always)]
                fn [<dot_ $t>](&self, a: &[$t], b: &[$t]) -> f64 {
                    assert_eq!(a.len(), b.len(), "Vector lengths must match for dot product");
                    a.iter()
                        .zip(b)
                        .map(|(&lhs, &rhs)| (lhs as f64) * (rhs as f64))
                        .sum()
                }
            }
        )+
    };
}
23
/// Generates one `scal_<type>` method per listed element type.
///
/// Each generated method scales a slice in place by a scalar factor.
#[macro_export]
macro_rules! impl_scal {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                /// Multiplies every element of `x` by `a`, in place.
                #[inline(always)]
                fn [<scal_ $t>](&self, a: $t, x: &mut [$t]) {
                    x.iter_mut().for_each(|elem| *elem *= a);
                }
            }
        )+
    };
}
40
/// Generates one `asum_<type>` method per listed element type.
///
/// Each generated method sums the absolute values of a slice, using the
/// element type's own `abs` and returning the element type.
#[macro_export]
macro_rules! impl_asum_signed {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                /// Sum of `|x[i]|` over the slice; an empty slice yields zero.
                #[inline(always)]
                fn [<asum_ $t>](&self, x: &[$t]) -> $t {
                    if x.is_empty() {
                        0 as $t
                    } else {
                        x.iter().map(|&value| value.abs()).sum()
                    }
                }
            }
        )+
    };
}
58
/// Generates one `asum_<type>` method per listed element type.
///
/// No `abs` call is made here: the generated method is a plain sum of the
/// elements, returned in the element type.
#[macro_export]
macro_rules! impl_asum_unsigned {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                /// Sum of all elements of `x`; an empty slice yields zero.
                #[inline(always)]
                fn [<asum_ $t>](&self, x: &[$t]) -> $t {
                    match x {
                        [] => 0 as $t,
                        values => values.iter().sum(),
                    }
                }
            }
        )+
    };
}
77
/// Generates `asum_f16` and `asum_bf16`.
///
/// Both methods widen each element to `f32`, take its absolute value and
/// return the `f32` sum.
#[macro_export]
macro_rules! impl_asum_half {
    () => {
        /// Sum of `|x[i]|` for `half::f16` input, accumulated and returned as `f32`.
        #[inline(always)]
        fn asum_f16(&self, x: &[half::f16]) -> f32 {
            if x.is_empty() {
                0.0f32
            } else {
                x.iter().map(|value| value.to_f32().abs()).sum()
            }
        }

        /// Sum of `|x[i]|` for `half::bf16` input, accumulated and returned as `f32`.
        #[inline(always)]
        fn asum_bf16(&self, x: &[half::bf16]) -> f32 {
            if x.is_empty() {
                0.0f32
            } else {
                x.iter().map(|value| value.to_f32().abs()).sum()
            }
        }
    };
}
99
/// Generates `gemm_f32` and `gemm_f64`, both of which forward to `gemm::gemm`.
///
/// The only accepted invocation is the literal `impl_gemm_sd!(f32, f64)`.
///
/// NOTE(review): the `#[cfg(feature = "rayon")]` checks sit inside the macro
/// body, so they are evaluated where the macro is expanded — confirm that the
/// expanding crate defines a `rayon` feature.
#[macro_export]
macro_rules! impl_gemm_sd {
    (f32, f64) => {
        /// `f32` matrix product written into `c` through `gemm::gemm`.
        ///
        /// `m`, `n`, `k` are forwarded as the problem dimensions. Each matrix is
        /// passed with the stride pair `(1, ld*)`; per the `gemm` crate's argument
        /// order this is presumably column stride 1 and row stride `ld*`
        /// (row-major) — confirm against the crate documentation.
        ///
        /// # Safety
        /// The raw pointers are handed to `gemm::gemm` without any checks, so the
        /// caller must guarantee they are valid for the accesses implied by the
        /// dimensions and strides.
        #[inline(always)]
        unsafe fn gemm_f32(
            &self,
            m: usize,
            n: usize,
            k: usize,
            a: *const f32,
            lda: usize,
            b: *const f32,
            ldb: usize,
            c: *mut f32,
            ldc: usize,
        ) {
            // Only ask for the rayon backend when more than one thread is configured.
            #[cfg(feature = "rayon")]
            let parallelism = match $crate::get_num_threads() {
                threads if threads > 1 => gemm::Parallelism::Rayon(threads),
                _ => gemm::Parallelism::None,
            };
            #[cfg(not(feature = "rayon"))]
            let parallelism = gemm::Parallelism::None;

            gemm::gemm(
                m,
                n,
                k,
                // Destination matrix and its two strides, then the "read dst" flag.
                c,
                1,
                ldc as isize,
                false,
                // Left-hand matrix and strides.
                a,
                1,
                lda as isize,
                // Right-hand matrix and strides.
                b,
                1,
                ldb as isize,
                // Scaling factors.
                0.0f32,
                1.0f32,
                // Conjugation flags (all disabled).
                false,
                false,
                false,
                parallelism,
            );
        }

        /// `f64` matrix product written into `c` through `gemm::gemm`.
        ///
        /// Takes the same arguments as the `f32` variant and forwards them in the
        /// same order.
        ///
        /// # Safety
        /// The raw pointers are handed to `gemm::gemm` without any checks, so the
        /// caller must guarantee they are valid for the accesses implied by the
        /// dimensions and strides.
        #[inline(always)]
        unsafe fn gemm_f64(
            &self,
            m: usize,
            n: usize,
            k: usize,
            a: *const f64,
            lda: usize,
            b: *const f64,
            ldb: usize,
            c: *mut f64,
            ldc: usize,
        ) {
            // Only ask for the rayon backend when more than one thread is configured.
            #[cfg(feature = "rayon")]
            let parallelism = match $crate::get_num_threads() {
                threads if threads > 1 => gemm::Parallelism::Rayon(threads),
                _ => gemm::Parallelism::None,
            };
            #[cfg(not(feature = "rayon"))]
            let parallelism = gemm::Parallelism::None;

            gemm::gemm(
                m,
                n,
                k,
                // Destination matrix and its two strides, then the "read dst" flag.
                c,
                1,
                ldc as isize,
                false,
                // Left-hand matrix and strides.
                a,
                1,
                lda as isize,
                // Right-hand matrix and strides.
                b,
                1,
                ldb as isize,
                // Scaling factors.
                0.0f64,
                1.0f64,
                // Conjugation flags (all disabled).
                false,
                false,
                false,
                parallelism,
            );
        }
    };
}
234
/// Generates one `gemm_<type>` method per listed element type.
///
/// Each generated method is a straightforward triple-loop matrix product. The
/// type list may be empty, in which case nothing is generated; the expansion
/// uses the same `*` repetition as the matcher so the two always agree.
#[macro_export]
macro_rules! impl_gemm {
    ($($t:ty),*) => {
        $(
            paste::paste! {
                /// Computes `c[i * ldc + j] = sum over l of a[i * lda + l] * b[l * ldb + j]`
                /// for `i < m`, `j < n`, `l < k`.
                ///
                /// Every addressed element of `c` is overwritten; its previous
                /// contents are not read.
                ///
                /// # Safety
                /// `a` and `b` must be valid for reads, and `c` valid for writes, at
                /// every offset produced by the index expressions above.
                #[inline(always)]
                unsafe fn [<gemm_ $t>](
                    &self,
                    m: usize,
                    n: usize,
                    k: usize,
                    a: *const $t,
                    lda: usize,
                    b: *const $t,
                    ldb: usize,
                    c: *mut $t,
                    ldc: usize,
                ) {
                    for i in 0..m {
                        for j in 0..n {
                            let mut sum = 0 as $t;
                            for l in 0..k {
                                sum += *a.add(i * lda + l) * *b.add(l * ldb + j);
                            }
                            *c.add(i * ldc + j) = sum;
                        }
                    }
                }
            }
        )*
    };
}
279
/// Generates `gemm_f16` (forwarding to `gemm::gemm`) and `gemm_bf16` (a
/// triple-loop fallback).
///
/// NOTE(review): the `#[cfg(feature = "rayon")]` checks sit inside the macro
/// body, so they are evaluated where the macro is expanded — confirm that the
/// expanding crate defines a `rayon` feature.
#[macro_export]
macro_rules! impl_gemm_half {
    () => {
        /// `half::f16` matrix product written into `c` through `gemm::gemm`.
        ///
        /// Each matrix is passed with the stride pair `(1, ld*)`; per the `gemm`
        /// crate's argument order this is presumably column stride 1 and row
        /// stride `ld*` (row-major) — confirm against the crate documentation.
        ///
        /// # Safety
        /// The raw pointers are handed to `gemm::gemm` without any checks, so the
        /// caller must guarantee they are valid for the accesses implied by the
        /// dimensions and strides.
        #[inline(always)]
        unsafe fn gemm_f16(
            &self,
            m: usize,
            n: usize,
            k: usize,
            a: *const half::f16,
            lda: usize,
            b: *const half::f16,
            ldb: usize,
            c: *mut half::f16,
            ldc: usize,
        ) {
            // Only ask for the rayon backend when more than one thread is configured.
            #[cfg(feature = "rayon")]
            let parallelism = {
                let num_threads = $crate::get_num_threads();
                if num_threads > 1 {
                    gemm::Parallelism::Rayon(num_threads)
                } else {
                    gemm::Parallelism::None
                }
            };
            #[cfg(not(feature = "rayon"))]
            let parallelism = gemm::Parallelism::None;

            gemm::gemm(
                m,
                n,
                k,
                c,
                1,
                ldc as isize,
                false,
                a,
                1,
                lda as isize,
                b,
                1,
                ldb as isize,
                gemm::f16::ZERO,
                gemm::f16::ONE,
                false,
                false,
                false,
                parallelism,
            );
        }

        /// Computes `c[i * ldc + j] = sum over l of a[i * lda + l] * b[l * ldb + j]`
        /// for `half::bf16` matrices.
        ///
        /// Products are accumulated in `f32` and converted to `bf16` once per
        /// output element. Accumulating directly in `bf16` loses small addends
        /// as soon as the running sum grows, which corrupts results for large `k`.
        ///
        /// # Safety
        /// `a` and `b` must be valid for reads, and `c` valid for writes, at
        /// every offset produced by the index expressions above.
        #[inline(always)]
        unsafe fn gemm_bf16(
            &self,
            m: usize,
            n: usize,
            k: usize,
            a: *const half::bf16,
            lda: usize,
            b: *const half::bf16,
            ldb: usize,
            c: *mut half::bf16,
            ldc: usize,
        ) {
            for i in 0..m {
                for j in 0..n {
                    // Wide accumulator: a single rounding happens at the final store.
                    let mut acc = 0.0f32;
                    for l in 0..k {
                        let a_val = (*a.add(i * lda + l)).to_f32();
                        let b_val = (*b.add(l * ldb + j)).to_f32();
                        acc += a_val * b_val;
                    }
                    *c.add(i * ldc + j) = half::bf16::from_f32(acc);
                }
            }
        }
    };
}
392
/// Generates one `v_exp_<type>` method per listed element type.
#[macro_export]
macro_rules! impl_v_exp {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                /// Element-wise exponential: `out[i] = x[i].exp()`.
                ///
                /// # Panics
                /// Panics when `x` and `out` differ in length.
                #[inline(always)]
                fn [<v_exp_ $t>](&self, x: &[$t], out: &mut [$t]) {
                    assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
                    for (dst, &src) in out.iter_mut().zip(x) {
                        *dst = src.exp();
                    }
                }
            }
        )+
    };
}
416
/// Generates `v_exp_f16` and `v_exp_bf16`.
///
/// Each element is widened to `f32`, exponentiated, and converted back with
/// `from_f32`.
#[macro_export]
macro_rules! impl_v_exp_half {
    () => {
        /// Element-wise exponential for `half::f16` slices.
        ///
        /// # Panics
        /// Panics when `x` and `out` differ in length.
        #[inline(always)]
        fn v_exp_f16(&self, x: &[half::f16], out: &mut [half::f16]) {
            assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
            for (dst, src) in out.iter_mut().zip(x) {
                let widened = src.to_f32();
                *dst = half::f16::from_f32(widened.exp());
            }
        }

        /// Element-wise exponential for `half::bf16` slices.
        ///
        /// # Panics
        /// Panics when `x` and `out` differ in length.
        #[inline(always)]
        fn v_exp_bf16(&self, x: &[half::bf16], out: &mut [half::bf16]) {
            assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
            for (dst, src) in out.iter_mut().zip(x) {
                let widened = src.to_f32();
                *dst = half::bf16::from_f32(widened.exp());
            }
        }
    };
}
446
/// Generates one `v_sin_<type>` method per listed element type.
#[macro_export]
macro_rules! impl_v_sin {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                /// Element-wise sine: `out[i] = x[i].sin()`.
                ///
                /// # Panics
                /// Panics when `x` and `out` differ in length.
                #[inline(always)]
                fn [<v_sin_ $t>](&self, x: &[$t], out: &mut [$t]) {
                    assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
                    for (dst, &src) in out.iter_mut().zip(x) {
                        *dst = src.sin();
                    }
                }
            }
        )+
    };
}
468
/// Generates `v_sin_f16` and `v_sin_bf16`.
///
/// Each element is widened to `f32`, passed through `sin`, and converted back
/// with `from_f32`.
#[macro_export]
macro_rules! impl_v_sin_half {
    () => {
        /// Element-wise sine for `half::f16` slices.
        ///
        /// # Panics
        /// Panics when `x` and `out` differ in length.
        #[inline(always)]
        fn v_sin_f16(&self, x: &[half::f16], out: &mut [half::f16]) {
            assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
            for (dst, src) in out.iter_mut().zip(x) {
                let widened = src.to_f32();
                *dst = half::f16::from_f32(widened.sin());
            }
        }

        /// Element-wise sine for `half::bf16` slices.
        ///
        /// # Panics
        /// Panics when `x` and `out` differ in length.
        #[inline(always)]
        fn v_sin_bf16(&self, x: &[half::bf16], out: &mut [half::bf16]) {
            assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
            for (dst, src) in out.iter_mut().zip(x) {
                let widened = src.to_f32();
                *dst = half::bf16::from_f32(widened.sin());
            }
        }
    };
}
498
/// Generates one `v_cos_<type>` method per listed element type.
#[macro_export]
macro_rules! impl_v_cos {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                /// Element-wise cosine: `out[i] = x[i].cos()`.
                ///
                /// # Panics
                /// Panics when `x` and `out` differ in length.
                #[inline(always)]
                fn [<v_cos_ $t>](&self, x: &[$t], out: &mut [$t]) {
                    assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
                    for (dst, &src) in out.iter_mut().zip(x) {
                        *dst = src.cos();
                    }
                }
            }
        )+
    };
}
520
/// Generates `v_cos_f16` and `v_cos_bf16`.
///
/// Each element is widened to `f32`, passed through `cos`, and converted back
/// with `from_f32`.
#[macro_export]
macro_rules! impl_v_cos_half {
    () => {
        /// Element-wise cosine for `half::f16` slices.
        ///
        /// # Panics
        /// Panics when `x` and `out` differ in length.
        #[inline(always)]
        fn v_cos_f16(&self, x: &[half::f16], out: &mut [half::f16]) {
            assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
            for (dst, src) in out.iter_mut().zip(x) {
                let widened = src.to_f32();
                *dst = half::f16::from_f32(widened.cos());
            }
        }

        /// Element-wise cosine for `half::bf16` slices.
        ///
        /// # Panics
        /// Panics when `x` and `out` differ in length.
        #[inline(always)]
        fn v_cos_bf16(&self, x: &[half::bf16], out: &mut [half::bf16]) {
            assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
            for (dst, src) in out.iter_mut().zip(x) {
                let widened = src.to_f32();
                *dst = half::bf16::from_f32(widened.cos());
            }
        }
    };
}
550
/// Generates one `v_tanh_<type>` method per listed element type.
#[macro_export]
macro_rules! impl_v_tanh {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                /// Element-wise hyperbolic tangent: `out[i] = x[i].tanh()`.
                ///
                /// # Panics
                /// Panics when `x` and `out` differ in length.
                #[inline(always)]
                fn [<v_tanh_ $t>](&self, x: &[$t], out: &mut [$t]) {
                    assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
                    for (dst, &src) in out.iter_mut().zip(x) {
                        *dst = src.tanh();
                    }
                }
            }
        )+
    };
}
572
/// Generates `v_tanh_f16` and `v_tanh_bf16`.
///
/// Each element is widened to `f32`, passed through `tanh`, and converted back
/// with `from_f32`. Both methods carry `#[inline(always)]`, matching the other
/// half-precision element-wise generators in this file.
#[macro_export]
macro_rules! impl_v_tanh_half {
    () => {
        /// Element-wise hyperbolic tangent for `half::f16` slices.
        ///
        /// # Panics
        /// Panics when `x` and `out` differ in length.
        #[inline(always)]
        fn v_tanh_f16(&self, x: &[half::f16], out: &mut [half::f16]) {
            assert_eq!(
                x.len(),
                out.len(),
                "Input and output slices must have same length"
            );
            for (o, xi) in out.iter_mut().zip(x.iter()) {
                *o = half::f16::from_f32(xi.to_f32().tanh());
            }
        }

        /// Element-wise hyperbolic tangent for `half::bf16` slices.
        ///
        /// # Panics
        /// Panics when `x` and `out` differ in length.
        #[inline(always)]
        fn v_tanh_bf16(&self, x: &[half::bf16], out: &mut [half::bf16]) {
            assert_eq!(
                x.len(),
                out.len(),
                "Input and output slices must have same length"
            );
            for (o, xi) in out.iter_mut().zip(x.iter()) {
                *o = half::bf16::from_f32(xi.to_f32().tanh());
            }
        }
    };
}
600
/// Generates one `v_log_<type>` method per listed element type.
#[macro_export]
macro_rules! impl_v_log {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                /// Element-wise natural logarithm: `out[i] = x[i].ln()`.
                ///
                /// # Panics
                /// Panics when `x` and `out` differ in length.
                #[inline(always)]
                fn [<v_log_ $t>](&self, x: &[$t], out: &mut [$t]) {
                    assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
                    for (dst, &src) in out.iter_mut().zip(x) {
                        *dst = src.ln();
                    }
                }
            }
        )+
    };
}
622
/// Generates `v_log_f16` and `v_log_bf16`.
///
/// Each element is widened to `f32`, passed through `ln`, and converted back
/// with `from_f32`.
#[macro_export]
macro_rules! impl_v_log_half {
    () => {
        /// Element-wise natural logarithm for `half::f16` slices.
        ///
        /// # Panics
        /// Panics when `x` and `out` differ in length.
        #[inline(always)]
        fn v_log_f16(&self, x: &[half::f16], out: &mut [half::f16]) {
            assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
            for (dst, src) in out.iter_mut().zip(x) {
                let widened = src.to_f32();
                *dst = half::f16::from_f32(widened.ln());
            }
        }

        /// Element-wise natural logarithm for `half::bf16` slices.
        ///
        /// # Panics
        /// Panics when `x` and `out` differ in length.
        #[inline(always)]
        fn v_log_bf16(&self, x: &[half::bf16], out: &mut [half::bf16]) {
            assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
            for (dst, src) in out.iter_mut().zip(x) {
                let widened = src.to_f32();
                *dst = half::bf16::from_f32(widened.ln());
            }
        }
    };
}
652
/// Generates one `v_sqrt_<type>` method per listed element type.
#[macro_export]
macro_rules! impl_v_sqrt {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                /// Element-wise square root: `out[i] = x[i].sqrt()`.
                ///
                /// # Panics
                /// Panics when `x` and `out` differ in length.
                #[inline(always)]
                fn [<v_sqrt_ $t>](&self, x: &[$t], out: &mut [$t]) {
                    assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
                    for (dst, &src) in out.iter_mut().zip(x) {
                        *dst = src.sqrt();
                    }
                }
            }
        )+
    };
}
674
/// Generates `v_sqrt_f16` and `v_sqrt_bf16`.
///
/// Each element is widened to `f32`, passed through `sqrt`, and converted back
/// with `from_f32`.
#[macro_export]
macro_rules! impl_v_sqrt_half {
    () => {
        /// Element-wise square root for `half::f16` slices.
        ///
        /// # Panics
        /// Panics when `x` and `out` differ in length.
        #[inline(always)]
        fn v_sqrt_f16(&self, x: &[half::f16], out: &mut [half::f16]) {
            assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
            for (dst, src) in out.iter_mut().zip(x) {
                let widened = src.to_f32();
                *dst = half::f16::from_f32(widened.sqrt());
            }
        }

        /// Element-wise square root for `half::bf16` slices.
        ///
        /// # Panics
        /// Panics when `x` and `out` differ in length.
        #[inline(always)]
        fn v_sqrt_bf16(&self, x: &[half::bf16], out: &mut [half::bf16]) {
            assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
            for (dst, src) in out.iter_mut().zip(x) {
                let widened = src.to_f32();
                *dst = half::bf16::from_f32(widened.sqrt());
            }
        }
    };
}
704
/// Generates one `v_sqr_<type>` method per listed element type.
#[macro_export]
macro_rules! impl_v_sqr {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                /// Element-wise square: `out[i] = x[i] * x[i]`.
                ///
                /// # Panics
                /// Panics when `x` and `out` differ in length.
                #[inline(always)]
                fn [<v_sqr_ $t>](&self, x: &[$t], out: &mut [$t]) {
                    assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
                    for (dst, &src) in out.iter_mut().zip(x) {
                        *dst = src * src;
                    }
                }
            }
        )+
    };
}
726
/// Generates `v_sqr_f16` and `v_sqr_bf16`.
///
/// Squaring uses the half-precision type's own multiplication operator.
#[macro_export]
macro_rules! impl_v_sqr_half {
    () => {
        /// Element-wise square for `half::f16` slices: `out[i] = x[i] * x[i]`.
        ///
        /// # Panics
        /// Panics when `x` and `out` differ in length.
        #[inline(always)]
        fn v_sqr_f16(&self, x: &[half::f16], out: &mut [half::f16]) {
            assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
            for (dst, &src) in out.iter_mut().zip(x) {
                *dst = src * src;
            }
        }

        /// Element-wise square for `half::bf16` slices: `out[i] = x[i] * x[i]`.
        ///
        /// # Panics
        /// Panics when `x` and `out` differ in length.
        #[inline(always)]
        fn v_sqr_bf16(&self, x: &[half::bf16], out: &mut [half::bf16]) {
            assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
            for (dst, &src) in out.iter_mut().zip(x) {
                *dst = src * src;
            }
        }
    };
}
756
/// Generates one `v_add_<type>` method per listed element type.
#[macro_export]
macro_rules! impl_v_add {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                /// Element-wise sum: `out[i] = a[i] + b[i]`.
                ///
                /// # Panics
                /// Panics when `a`, `b` and `out` do not all share one length.
                #[inline(always)]
                fn [<v_add_ $t>](&self, a: &[$t], b: &[$t], out: &mut [$t]) {
                    assert_eq!(a.len(), b.len(), "Input slices must have same length");
                    assert_eq!(a.len(), out.len(), "Input and output slices must have same length");
                    for (dst, (&lhs, &rhs)) in out.iter_mut().zip(a.iter().zip(b)) {
                        *dst = lhs + rhs;
                    }
                }
            }
        )+
    };
}
779
/// Generates `v_add_f16` and `v_add_bf16`.
///
/// Addition uses the half-precision type's own `+` operator.
#[macro_export]
macro_rules! impl_v_add_half {
    () => {
        /// Element-wise sum for `half::f16` slices: `out[i] = a[i] + b[i]`.
        ///
        /// # Panics
        /// Panics when `a`, `b` and `out` do not all share one length.
        #[inline(always)]
        fn v_add_f16(&self, a: &[half::f16], b: &[half::f16], out: &mut [half::f16]) {
            assert_eq!(a.len(), b.len(), "Input slices must have same length");
            assert_eq!(a.len(), out.len(), "Input and output slices must have same length");
            for (dst, (&lhs, &rhs)) in out.iter_mut().zip(a.iter().zip(b)) {
                *dst = lhs + rhs;
            }
        }

        /// Element-wise sum for `half::bf16` slices: `out[i] = a[i] + b[i]`.
        ///
        /// # Panics
        /// Panics when `a`, `b` and `out` do not all share one length.
        #[inline(always)]
        fn v_add_bf16(&self, a: &[half::bf16], b: &[half::bf16], out: &mut [half::bf16]) {
            assert_eq!(a.len(), b.len(), "Input slices must have same length");
            assert_eq!(a.len(), out.len(), "Input and output slices must have same length");
            for (dst, (&lhs, &rhs)) in out.iter_mut().zip(a.iter().zip(b)) {
                *dst = lhs + rhs;
            }
        }
    };
}
811
/// Generates one `v_sub_<type>` method per listed element type.
#[macro_export]
macro_rules! impl_v_sub {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                /// Element-wise difference: `out[i] = a[i] - b[i]`.
                ///
                /// # Panics
                /// Panics when `a`, `b` and `out` do not all share one length.
                #[inline(always)]
                fn [<v_sub_ $t>](&self, a: &[$t], b: &[$t], out: &mut [$t]) {
                    assert_eq!(a.len(), b.len(), "Input slices must have same length");
                    assert_eq!(a.len(), out.len(), "Input and output slices must have same length");
                    for (dst, (&lhs, &rhs)) in out.iter_mut().zip(a.iter().zip(b)) {
                        *dst = lhs - rhs;
                    }
                }
            }
        )+
    };
}
834
/// Generates `v_sub_f16` and `v_sub_bf16`.
///
/// Subtraction uses the half-precision type's own `-` operator.
#[macro_export]
macro_rules! impl_v_sub_half {
    () => {
        /// Element-wise difference for `half::f16` slices: `out[i] = a[i] - b[i]`.
        ///
        /// # Panics
        /// Panics when `a`, `b` and `out` do not all share one length.
        #[inline(always)]
        fn v_sub_f16(&self, a: &[half::f16], b: &[half::f16], out: &mut [half::f16]) {
            assert_eq!(a.len(), b.len(), "Input slices must have same length");
            assert_eq!(a.len(), out.len(), "Input and output slices must have same length");
            for (dst, (&lhs, &rhs)) in out.iter_mut().zip(a.iter().zip(b)) {
                *dst = lhs - rhs;
            }
        }

        /// Element-wise difference for `half::bf16` slices: `out[i] = a[i] - b[i]`.
        ///
        /// # Panics
        /// Panics when `a`, `b` and `out` do not all share one length.
        #[inline(always)]
        fn v_sub_bf16(&self, a: &[half::bf16], b: &[half::bf16], out: &mut [half::bf16]) {
            assert_eq!(a.len(), b.len(), "Input slices must have same length");
            assert_eq!(a.len(), out.len(), "Input and output slices must have same length");
            for (dst, (&lhs, &rhs)) in out.iter_mut().zip(a.iter().zip(b)) {
                *dst = lhs - rhs;
            }
        }
    };
}
866
/// Generates one `v_mul_<type>` method per listed element type.
#[macro_export]
macro_rules! impl_v_mul {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                /// Element-wise product: `out[i] = a[i] * b[i]`.
                ///
                /// # Panics
                /// Panics when `a`, `b` and `out` do not all share one length.
                #[inline(always)]
                fn [<v_mul_ $t>](&self, a: &[$t], b: &[$t], out: &mut [$t]) {
                    assert_eq!(a.len(), b.len(), "Input slices must have same length");
                    assert_eq!(a.len(), out.len(), "Input and output slices must have same length");
                    for (dst, (&lhs, &rhs)) in out.iter_mut().zip(a.iter().zip(b)) {
                        *dst = lhs * rhs;
                    }
                }
            }
        )+
    };
}
889
/// Generates `v_mul_f16` and `v_mul_bf16`.
///
/// Multiplication uses the half-precision type's own `*` operator.
#[macro_export]
macro_rules! impl_v_mul_half {
    () => {
        /// Element-wise product for `half::f16` slices: `out[i] = a[i] * b[i]`.
        ///
        /// # Panics
        /// Panics when `a`, `b` and `out` do not all share one length.
        #[inline(always)]
        fn v_mul_f16(&self, a: &[half::f16], b: &[half::f16], out: &mut [half::f16]) {
            assert_eq!(a.len(), b.len(), "Input slices must have same length");
            assert_eq!(a.len(), out.len(), "Input and output slices must have same length");
            for (dst, (&lhs, &rhs)) in out.iter_mut().zip(a.iter().zip(b)) {
                *dst = lhs * rhs;
            }
        }

        /// Element-wise product for `half::bf16` slices: `out[i] = a[i] * b[i]`.
        ///
        /// # Panics
        /// Panics when `a`, `b` and `out` do not all share one length.
        #[inline(always)]
        fn v_mul_bf16(&self, a: &[half::bf16], b: &[half::bf16], out: &mut [half::bf16]) {
            assert_eq!(a.len(), b.len(), "Input slices must have same length");
            assert_eq!(a.len(), out.len(), "Input and output slices must have same length");
            for (dst, (&lhs, &rhs)) in out.iter_mut().zip(a.iter().zip(b)) {
                *dst = lhs * rhs;
            }
        }
    };
}
921
/// Generates one `v_div_<type>` method per listed element type.
#[macro_export]
macro_rules! impl_v_div {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                /// Element-wise quotient: `out[i] = a[i] / b[i]`.
                ///
                /// # Panics
                /// Panics when `a`, `b` and `out` do not all share one length.
                #[inline(always)]
                fn [<v_div_ $t>](&self, a: &[$t], b: &[$t], out: &mut [$t]) {
                    assert_eq!(a.len(), b.len(), "Input slices must have same length");
                    assert_eq!(a.len(), out.len(), "Input and output slices must have same length");
                    for (dst, (&numerator, &denominator)) in out.iter_mut().zip(a.iter().zip(b)) {
                        *dst = numerator / denominator;
                    }
                }
            }
        )+
    };
}
944
/// Generates `v_div_f16` and `v_div_bf16`.
///
/// Division uses the half-precision type's own `/` operator.
#[macro_export]
macro_rules! impl_v_div_half {
    () => {
        /// Element-wise quotient for `half::f16` slices: `out[i] = a[i] / b[i]`.
        ///
        /// # Panics
        /// Panics when `a`, `b` and `out` do not all share one length.
        #[inline(always)]
        fn v_div_f16(&self, a: &[half::f16], b: &[half::f16], out: &mut [half::f16]) {
            assert_eq!(a.len(), b.len(), "Input slices must have same length");
            assert_eq!(a.len(), out.len(), "Input and output slices must have same length");
            for (dst, (&numerator, &denominator)) in out.iter_mut().zip(a.iter().zip(b)) {
                *dst = numerator / denominator;
            }
        }

        /// Element-wise quotient for `half::bf16` slices: `out[i] = a[i] / b[i]`.
        ///
        /// # Panics
        /// Panics when `a`, `b` and `out` do not all share one length.
        #[inline(always)]
        fn v_div_bf16(&self, a: &[half::bf16], b: &[half::bf16], out: &mut [half::bf16]) {
            assert_eq!(a.len(), b.len(), "Input slices must have same length");
            assert_eq!(a.len(), out.len(), "Input and output slices must have same length");
            for (dst, (&numerator, &denominator)) in out.iter_mut().zip(a.iter().zip(b)) {
                *dst = numerator / denominator;
            }
        }
    };
}
976
/// Generates one `v_div_scalar_<type>` method per listed element type.
#[macro_export]
macro_rules! impl_v_div_scalar {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                /// Divides every element by one scalar: `out[i] = x[i] / scalar`.
                ///
                /// # Panics
                /// Panics when `x` and `out` differ in length.
                #[inline(always)]
                fn [<v_div_scalar_ $t>](&self, x: &[$t], scalar: $t, out: &mut [$t]) {
                    assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
                    for (dst, &src) in out.iter_mut().zip(x) {
                        *dst = src / scalar;
                    }
                }
            }
        )+
    };
}
998
/// Generates `v_div_scalar_f16` and `v_div_scalar_bf16`.
///
/// Division uses the half-precision type's own `/` operator.
#[macro_export]
macro_rules! impl_v_div_scalar_half {
    () => {
        /// Divides every `half::f16` element by `scalar`: `out[i] = x[i] / scalar`.
        ///
        /// # Panics
        /// Panics when `x` and `out` differ in length.
        #[inline(always)]
        fn v_div_scalar_f16(&self, x: &[half::f16], scalar: half::f16, out: &mut [half::f16]) {
            assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
            for (dst, &src) in out.iter_mut().zip(x) {
                *dst = src / scalar;
            }
        }

        /// Divides every `half::bf16` element by `scalar`: `out[i] = x[i] / scalar`.
        ///
        /// # Panics
        /// Panics when `x` and `out` differ in length.
        #[inline(always)]
        fn v_div_scalar_bf16(&self, x: &[half::bf16], scalar: half::bf16, out: &mut [half::bf16]) {
            assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
            for (dst, &src) in out.iter_mut().zip(x) {
                *dst = src / scalar;
            }
        }
    };
}
1028
/// Generates one `v_tan_<type>` method per listed element type.
#[macro_export]
macro_rules! impl_v_tan {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                /// Element-wise tangent: `out[i] = x[i].tan()`.
                ///
                /// # Panics
                /// Panics when `x` and `out` differ in length.
                #[inline(always)]
                fn [<v_tan_ $t>](&self, x: &[$t], out: &mut [$t]) {
                    assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
                    for (dst, &src) in out.iter_mut().zip(x) {
                        *dst = src.tan();
                    }
                }
            }
        )+
    };
}
1050
/// Generates `v_tan_f16` and `v_tan_bf16`.
///
/// Each element is widened to `f32`, passed through `tan`, and converted back
/// with `from_f32`.
#[macro_export]
macro_rules! impl_v_tan_half {
    () => {
        /// Element-wise tangent for `half::f16` slices.
        ///
        /// # Panics
        /// Panics when `x` and `out` differ in length.
        #[inline(always)]
        fn v_tan_f16(&self, x: &[half::f16], out: &mut [half::f16]) {
            assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
            for (dst, src) in out.iter_mut().zip(x) {
                let widened = src.to_f32();
                *dst = half::f16::from_f32(widened.tan());
            }
        }

        /// Element-wise tangent for `half::bf16` slices.
        ///
        /// # Panics
        /// Panics when `x` and `out` differ in length.
        #[inline(always)]
        fn v_tan_bf16(&self, x: &[half::bf16], out: &mut [half::bf16]) {
            assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
            for (dst, src) in out.iter_mut().zip(x) {
                let widened = src.to_f32();
                *dst = half::bf16::from_f32(widened.tan());
            }
        }
    };
}
1080
/// Generates one `v_recip_<type>` method per listed element type.
#[macro_export]
macro_rules! impl_v_recip {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                /// Element-wise reciprocal: `out[i] = 1.0 / x[i]`.
                ///
                /// # Panics
                /// Panics when `x` and `out` differ in length.
                #[inline(always)]
                fn [<v_recip_ $t>](&self, x: &[$t], out: &mut [$t]) {
                    assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
                    for (dst, &src) in out.iter_mut().zip(x) {
                        *dst = 1.0 / src;
                    }
                }
            }
        )+
    };
}
1102
/// Generates `v_recip_f16` and `v_recip_bf16`.
///
/// The reciprocal is formed by dividing the type's `ONE` constant by each
/// element with the half-precision `/` operator.
#[macro_export]
macro_rules! impl_v_recip_half {
    () => {
        /// Element-wise reciprocal for `half::f16` slices: `out[i] = ONE / x[i]`.
        ///
        /// # Panics
        /// Panics when `x` and `out` differ in length.
        #[inline(always)]
        fn v_recip_f16(&self, x: &[half::f16], out: &mut [half::f16]) {
            assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
            for (dst, &src) in out.iter_mut().zip(x) {
                *dst = half::f16::ONE / src;
            }
        }

        /// Element-wise reciprocal for `half::bf16` slices: `out[i] = ONE / x[i]`.
        ///
        /// # Panics
        /// Panics when `x` and `out` differ in length.
        #[inline(always)]
        fn v_recip_bf16(&self, x: &[half::bf16], out: &mut [half::bf16]) {
            assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
            for (dst, &src) in out.iter_mut().zip(x) {
                *dst = half::bf16::ONE / src;
            }
        }
    };
}
1132
/// Generates one `v_floor_<type>` method per listed element type.
#[macro_export]
macro_rules! impl_v_floor {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                /// Element-wise floor: `out[i] = x[i].floor()`.
                ///
                /// # Panics
                /// Panics when `x` and `out` differ in length.
                #[inline(always)]
                fn [<v_floor_ $t>](&self, x: &[$t], out: &mut [$t]) {
                    assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
                    for (dst, &src) in out.iter_mut().zip(x) {
                        *dst = src.floor();
                    }
                }
            }
        )+
    };
}
1154
/// Generates `v_floor_f16` and `v_floor_bf16`.
///
/// Each element is widened to `f32`, passed through `floor`, and converted
/// back with `from_f32`.
#[macro_export]
macro_rules! impl_v_floor_half {
    () => {
        /// Element-wise floor for `half::f16` slices.
        ///
        /// # Panics
        /// Panics when `x` and `out` differ in length.
        #[inline(always)]
        fn v_floor_f16(&self, x: &[half::f16], out: &mut [half::f16]) {
            assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
            for (dst, src) in out.iter_mut().zip(x) {
                let widened = src.to_f32();
                *dst = half::f16::from_f32(widened.floor());
            }
        }

        /// Element-wise floor for `half::bf16` slices.
        ///
        /// # Panics
        /// Panics when `x` and `out` differ in length.
        #[inline(always)]
        fn v_floor_bf16(&self, x: &[half::bf16], out: &mut [half::bf16]) {
            assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
            for (dst, src) in out.iter_mut().zip(x) {
                let widened = src.to_f32();
                *dst = half::bf16::from_f32(widened.floor());
            }
        }
    };
}
1184
/// Generates one `v_ceil_<type>` method per listed element type.
#[macro_export]
macro_rules! impl_v_ceil {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                /// Element-wise ceiling: `out[i] = x[i].ceil()`.
                ///
                /// # Panics
                /// Panics when `x` and `out` differ in length.
                #[inline(always)]
                fn [<v_ceil_ $t>](&self, x: &[$t], out: &mut [$t]) {
                    assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
                    for (dst, &src) in out.iter_mut().zip(x) {
                        *dst = src.ceil();
                    }
                }
            }
        )+
    };
}
1206
/// Generates `v_ceil_f16` and `v_ceil_bf16`.
///
/// Each element is widened to `f32`, passed through `ceil`, and converted back
/// with `from_f32`.
#[macro_export]
macro_rules! impl_v_ceil_half {
    () => {
        /// Element-wise ceiling for `half::f16` slices.
        ///
        /// # Panics
        /// Panics when `x` and `out` differ in length.
        #[inline(always)]
        fn v_ceil_f16(&self, x: &[half::f16], out: &mut [half::f16]) {
            assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
            for (dst, src) in out.iter_mut().zip(x) {
                let widened = src.to_f32();
                *dst = half::f16::from_f32(widened.ceil());
            }
        }

        /// Element-wise ceiling for `half::bf16` slices.
        ///
        /// # Panics
        /// Panics when `x` and `out` differ in length.
        #[inline(always)]
        fn v_ceil_bf16(&self, x: &[half::bf16], out: &mut [half::bf16]) {
            assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
            for (dst, src) in out.iter_mut().zip(x) {
                let widened = src.to_f32();
                *dst = half::bf16::from_f32(widened.ceil());
            }
        }
    };
}
1236
/// Generates one `v_round_<type>` method per listed element type.
#[macro_export]
macro_rules! impl_v_round {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                /// Element-wise rounding: `out[i] = x[i].round()`.
                ///
                /// # Panics
                /// Panics when `x` and `out` differ in length.
                #[inline(always)]
                fn [<v_round_ $t>](&self, x: &[$t], out: &mut [$t]) {
                    assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
                    for (dst, &src) in out.iter_mut().zip(x) {
                        *dst = src.round();
                    }
                }
            }
        )+
    };
}
1258
/// Generates `v_round_f16` and `v_round_bf16`.
///
/// Each element is widened to `f32`, passed through `round`, and converted
/// back with `from_f32`.
#[macro_export]
macro_rules! impl_v_round_half {
    () => {
        /// Element-wise rounding for `half::f16` slices.
        ///
        /// # Panics
        /// Panics when `x` and `out` differ in length.
        #[inline(always)]
        fn v_round_f16(&self, x: &[half::f16], out: &mut [half::f16]) {
            assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
            for (dst, src) in out.iter_mut().zip(x) {
                let widened = src.to_f32();
                *dst = half::f16::from_f32(widened.round());
            }
        }

        /// Element-wise rounding for `half::bf16` slices.
        ///
        /// # Panics
        /// Panics when `x` and `out` differ in length.
        #[inline(always)]
        fn v_round_bf16(&self, x: &[half::bf16], out: &mut [half::bf16]) {
            assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
            for (dst, src) in out.iter_mut().zip(x) {
                let widened = src.to_f32();
                *dst = half::bf16::from_f32(widened.round());
            }
        }
    };
}
1288
/// Generates one `v_abs_<type>` method per listed element type.
#[macro_export]
macro_rules! impl_v_abs {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                /// Element-wise absolute value: `out[i] = x[i].abs()`.
                ///
                /// # Panics
                /// Panics when `x` and `out` differ in length.
                #[inline(always)]
                fn [<v_abs_ $t>](&self, x: &[$t], out: &mut [$t]) {
                    assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
                    for (dst, &src) in out.iter_mut().zip(x) {
                        *dst = src.abs();
                    }
                }
            }
        )+
    };
}
1310
/// Generates `v_abs_f16` and `v_abs_bf16`.
///
/// Each element is widened to `f32`, passed through `abs`, and converted back
/// with `from_f32`.
#[macro_export]
macro_rules! impl_v_abs_half {
    () => {
        /// Element-wise absolute value for `half::f16` slices.
        ///
        /// # Panics
        /// Panics when `x` and `out` differ in length.
        #[inline(always)]
        fn v_abs_f16(&self, x: &[half::f16], out: &mut [half::f16]) {
            assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
            for (dst, src) in out.iter_mut().zip(x) {
                let widened = src.to_f32();
                *dst = half::f16::from_f32(widened.abs());
            }
        }

        /// Element-wise absolute value for `half::bf16` slices.
        ///
        /// # Panics
        /// Panics when `x` and `out` differ in length.
        #[inline(always)]
        fn v_abs_bf16(&self, x: &[half::bf16], out: &mut [half::bf16]) {
            assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
            for (dst, src) in out.iter_mut().zip(x) {
                let widened = src.to_f32();
                *dst = half::bf16::from_f32(widened.abs());
            }
        }
    };
}
1340
/// Generates one `v_neg_<type>` method per listed element type.
#[macro_export]
macro_rules! impl_v_neg {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                /// Element-wise negation: `out[i] = -x[i]`.
                ///
                /// # Panics
                /// Panics when `x` and `out` differ in length.
                #[inline(always)]
                fn [<v_neg_ $t>](&self, x: &[$t], out: &mut [$t]) {
                    assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
                    for (dst, &src) in out.iter_mut().zip(x) {
                        *dst = -src;
                    }
                }
            }
        )+
    };
}
1362
/// Generates `v_neg_f16` and `v_neg_bf16`.
///
/// Negation uses the half-precision type's own unary `-` operator.
#[macro_export]
macro_rules! impl_v_neg_half {
    () => {
        /// Element-wise negation for `half::f16` slices: `out[i] = -x[i]`.
        ///
        /// # Panics
        /// Panics when `x` and `out` differ in length.
        #[inline(always)]
        fn v_neg_f16(&self, x: &[half::f16], out: &mut [half::f16]) {
            assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
            for (dst, &src) in out.iter_mut().zip(x) {
                *dst = -src;
            }
        }

        /// Element-wise negation for `half::bf16` slices: `out[i] = -x[i]`.
        ///
        /// # Panics
        /// Panics when `x` and `out` differ in length.
        #[inline(always)]
        fn v_neg_bf16(&self, x: &[half::bf16], out: &mut [half::bf16]) {
            assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
            for (dst, &src) in out.iter_mut().zip(x) {
                *dst = -src;
            }
        }
    };
}
1392
/// Generates one `v_pow_<type>` method per listed element type.
#[macro_export]
macro_rules! impl_v_pow {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                /// Element-wise power: `out[i] = a[i].powf(b[i])`.
                ///
                /// # Panics
                /// Panics when `a`, `b` and `out` do not all share one length.
                #[inline(always)]
                fn [<v_pow_ $t>](&self, a: &[$t], b: &[$t], out: &mut [$t]) {
                    assert_eq!(a.len(), b.len(), "Input slices must have same length");
                    assert_eq!(a.len(), out.len(), "Input and output slices must have same length");
                    for (dst, (&base, &exponent)) in out.iter_mut().zip(a.iter().zip(b)) {
                        *dst = base.powf(exponent);
                    }
                }
            }
        )+
    };
}
1415
/// Generates one `relu_<type>` method per listed element type.
///
/// The comparison value is the float literal `0.0`.
#[macro_export]
macro_rules! impl_relu {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                /// Rectified linear unit: `out[i] = x[i].max(0.0)`.
                ///
                /// # Panics
                /// Panics when `x` and `out` differ in length.
                #[inline(always)]
                fn [<relu_ $t>](&self, x: &[$t], out: &mut [$t]) {
                    assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
                    for (dst, &src) in out.iter_mut().zip(x) {
                        *dst = src.max(0.0);
                    }
                }
            }
        )+
    };
}
1437
/// Generates one `relu_<type>` method per listed element type.
///
/// The comparison value is the integer literal `0`.
#[macro_export]
macro_rules! impl_relu_int {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                /// Rectified linear unit: `out[i] = x[i].max(0)`.
                ///
                /// # Panics
                /// Panics when `x` and `out` differ in length.
                #[inline(always)]
                fn [<relu_ $t>](&self, x: &[$t], out: &mut [$t]) {
                    assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
                    for (dst, &src) in out.iter_mut().zip(x) {
                        *dst = src.max(0);
                    }
                }
            }
        )+
    };
}
1459
/// Generates one `relu_<type>` method per listed element type.
///
/// No clamping is performed here: the generated method copies the input
/// unchanged.
#[macro_export]
macro_rules! impl_relu_uint {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                /// Copies `x` into `out` element for element.
                ///
                /// # Panics
                /// Panics when `x` and `out` differ in length.
                #[inline(always)]
                fn [<relu_ $t>](&self, x: &[$t], out: &mut [$t]) {
                    assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
                    // Lengths were just checked, so this is a straight element copy.
                    out.copy_from_slice(x);
                }
            }
        )+
    };
}
1481
/// Generates `relu_f16` and `relu_bf16`.
///
/// Clamping uses the half-precision type's own `max` against its `ZERO`
/// constant.
#[macro_export]
macro_rules! impl_relu_half {
    () => {
        /// Rectified linear unit for `half::f16` slices: `out[i] = x[i].max(ZERO)`.
        ///
        /// # Panics
        /// Panics when `x` and `out` differ in length.
        #[inline(always)]
        fn relu_f16(&self, x: &[half::f16], out: &mut [half::f16]) {
            assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
            for (dst, &src) in out.iter_mut().zip(x) {
                *dst = src.max(half::f16::ZERO);
            }
        }

        /// Rectified linear unit for `half::bf16` slices: `out[i] = x[i].max(ZERO)`.
        ///
        /// # Panics
        /// Panics when `x` and `out` differ in length.
        #[inline(always)]
        fn relu_bf16(&self, x: &[half::bf16], out: &mut [half::bf16]) {
            assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
            for (dst, &src) in out.iter_mut().zip(x) {
                *dst = src.max(half::bf16::ZERO);
            }
        }
    };
}
1511
/// Generates one `sum_<type>` method per `element => accumulator` pair.
///
/// Elements are cast to the accumulator type while summing, and the total is
/// cast to `f64` on return.
#[macro_export]
macro_rules! impl_sum_int {
    ($($t:ty => $acc:ty),+) => {
        $(
            paste::paste! {
                /// Sum of all elements of `x`, returned as `f64`; an empty slice
                /// yields `0.0`.
                #[inline(always)]
                fn [<sum_ $t>](&self, x: &[$t]) -> f64 {
                    if x.is_empty() {
                        return 0.0;
                    }
                    let total: $acc = x.iter().fold(0, |running, &val| running + val as $acc);
                    total as f64
                }
            }
        )+
    };
}
1533
/// Generates one `mean_<type>` method per listed element type.
///
/// Each generated method calls the `sum_<type>` method on `self`, which must
/// be provided separately.
#[macro_export]
macro_rules! impl_mean_float {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                /// Mean of `x`: the `sum_<type>` result divided by the element
                /// count. An empty slice yields zero.
                #[inline(always)]
                fn [<mean_ $t>](&self, x: &[$t]) -> $t {
                    match x.len() {
                        0 => 0 as $t,
                        count => self.[<sum_ $t>](x) / (count as $t),
                    }
                }
            }
        )+
    };
}
1552
/// Generates `mean_f16` and `mean_bf16`.
///
/// Both call the matching `sum_f16` / `sum_bf16` method on `self`, which must
/// be provided separately, and return `f64`.
#[macro_export]
macro_rules! impl_mean_half {
    () => {
        /// Mean of a `half::f16` slice as `f64`; an empty slice yields `0.0`.
        #[inline(always)]
        fn mean_f16(&self, x: &[half::f16]) -> f64 {
            match x.len() {
                0 => 0.0f64,
                count => self.sum_f16(x) / (count as f64),
            }
        }

        /// Mean of a `half::bf16` slice as `f64`; an empty slice yields `0.0`.
        #[inline(always)]
        fn mean_bf16(&self, x: &[half::bf16]) -> f64 {
            match x.len() {
                0 => 0.0f64,
                count => self.sum_bf16(x) / (count as f64),
            }
        }
    };
}
1576
/// Generates one `mean_<type>` method per listed element type.
///
/// Each generated method calls the `sum_<type>` method on `self`, which must
/// be provided separately, and returns `f64`.
#[macro_export]
macro_rules! impl_mean_int {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                /// Mean of `x` as `f64`: the `sum_<type>` result divided by the
                /// element count. An empty slice yields `0.0`.
                #[inline(always)]
                fn [<mean_ $t>](&self, x: &[$t]) -> f64 {
                    match x.len() {
                        0 => 0.0f64,
                        count => self.[<sum_ $t>](x) / (count as f64),
                    }
                }
            }
        )+
    };
}
1595
/// Generates one `v_add_scalar_<type>` method per listed element type.
#[macro_export]
macro_rules! impl_v_add_scalar {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                /// Adds one scalar to every element: `out[i] = x[i] + scalar`.
                ///
                /// # Panics
                /// Panics when `x` and `out` differ in length.
                #[inline(always)]
                fn [<v_add_scalar_ $t>](&self, x: &[$t], scalar: $t, out: &mut [$t]) {
                    assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
                    for (dst, &src) in out.iter_mut().zip(x) {
                        *dst = src + scalar;
                    }
                }
            }
        )+
    };
}
1617
/// Generates `v_add_scalar_f16` and `v_add_scalar_bf16`.
///
/// Addition uses the half-precision type's own `+` operator.
#[macro_export]
macro_rules! impl_v_add_scalar_half {
    () => {
        /// Adds `scalar` to every `half::f16` element: `out[i] = x[i] + scalar`.
        ///
        /// # Panics
        /// Panics when `x` and `out` differ in length.
        #[inline(always)]
        fn v_add_scalar_f16(&self, x: &[half::f16], scalar: half::f16, out: &mut [half::f16]) {
            assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
            for (dst, &src) in out.iter_mut().zip(x) {
                *dst = src + scalar;
            }
        }

        /// Adds `scalar` to every `half::bf16` element: `out[i] = x[i] + scalar`.
        ///
        /// # Panics
        /// Panics when `x` and `out` differ in length.
        #[inline(always)]
        fn v_add_scalar_bf16(&self, x: &[half::bf16], scalar: half::bf16, out: &mut [half::bf16]) {
            assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
            for (dst, &src) in out.iter_mut().zip(x) {
                *dst = src + scalar;
            }
        }
    };
}
1647
/// Generates one `v_sub_scalar_<type>` method per listed element type.
#[macro_export]
macro_rules! impl_v_sub_scalar {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                /// Subtracts one scalar from every element: `out[i] = x[i] - scalar`.
                ///
                /// # Panics
                /// Panics when `x` and `out` differ in length.
                #[inline(always)]
                fn [<v_sub_scalar_ $t>](&self, x: &[$t], scalar: $t, out: &mut [$t]) {
                    assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
                    for (dst, &src) in out.iter_mut().zip(x) {
                        *dst = src - scalar;
                    }
                }
            }
        )+
    };
}
1669
/// Generates `v_sub_scalar_f16` and `v_sub_scalar_bf16`.
///
/// Both methods store `x[i] - scalar` into `out[i]` for each index `i`, using the subtraction
/// operator of the respective `half` type.
///
/// # Panics
///
/// The generated methods panic when `x` and `out` have different lengths.
#[macro_export]
macro_rules! impl_v_sub_scalar_half {
    () => {
        #[inline(always)]
        fn v_sub_scalar_f16(&self, x: &[half::f16], scalar: half::f16, out: &mut [half::f16]) {
            assert_eq!(
                x.len(),
                out.len(),
                "Input and output slices must have same length"
            );
            x.iter()
                .zip(out.iter_mut())
                .for_each(|(&src, dst)| *dst = src - scalar);
        }

        #[inline(always)]
        fn v_sub_scalar_bf16(&self, x: &[half::bf16], scalar: half::bf16, out: &mut [half::bf16]) {
            assert_eq!(
                x.len(),
                out.len(),
                "Input and output slices must have same length"
            );
            x.iter()
                .zip(out.iter_mut())
                .for_each(|(&src, dst)| *dst = src - scalar);
        }
    };
}
1699
/// Generates `v_mul_scalar_<t>` methods for each listed element type.
///
/// Every generated method stores `x[i] * scalar` into `out[i]` for each index `i`.
///
/// # Panics
///
/// The generated methods panic when `x` and `out` have different lengths.
#[macro_export]
macro_rules! impl_v_mul_scalar {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                #[inline(always)]
                fn [<v_mul_scalar_ $t>](&self, x: &[$t], scalar: $t, out: &mut [$t]) {
                    assert_eq!(
                        x.len(),
                        out.len(),
                        "Input and output slices must have same length"
                    );
                    x.iter()
                        .zip(out.iter_mut())
                        .for_each(|(&src, dst)| *dst = src * scalar);
                }
            }
        )+
    };
}
1721
/// Generates `v_mul_scalar_f16` and `v_mul_scalar_bf16`.
///
/// Both methods store `x[i] * scalar` into `out[i]` for each index `i`, using the
/// multiplication operator of the respective `half` type.
///
/// # Panics
///
/// The generated methods panic when `x` and `out` have different lengths.
#[macro_export]
macro_rules! impl_v_mul_scalar_half {
    () => {
        #[inline(always)]
        fn v_mul_scalar_f16(&self, x: &[half::f16], scalar: half::f16, out: &mut [half::f16]) {
            assert_eq!(
                x.len(),
                out.len(),
                "Input and output slices must have same length"
            );
            x.iter()
                .zip(out.iter_mut())
                .for_each(|(&src, dst)| *dst = src * scalar);
        }

        #[inline(always)]
        fn v_mul_scalar_bf16(&self, x: &[half::bf16], scalar: half::bf16, out: &mut [half::bf16]) {
            assert_eq!(
                x.len(),
                out.len(),
                "Input and output slices must have same length"
            );
            x.iter()
                .zip(out.iter_mut())
                .for_each(|(&src, dst)| *dst = src * scalar);
        }
    };
}
1751
/// Generates `max_v_<t>` methods for each listed element type.
///
/// Every generated method returns the largest element of `x` according to `partial_cmp`.
/// Pairs for which `partial_cmp` yields `None` (e.g. a NaN operand) are treated as equal, so
/// the result for such inputs depends on element order.
///
/// # Panics
///
/// The generated methods panic when `x` is empty.
#[macro_export]
macro_rules! impl_max_v {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                #[inline(always)]
                fn [<max_v_ $t>](&self, x: &[$t]) -> $t {
                    let largest = x.iter().copied().max_by(|lhs, rhs| {
                        lhs.partial_cmp(rhs).unwrap_or(std::cmp::Ordering::Equal)
                    });
                    match largest {
                        Some(value) => value,
                        // `max_by` yields `None` only for an empty slice.
                        None => panic!("Cannot find maximum of empty vector"),
                    }
                }
            }
        )+
    };
}
1769
/// Generates `min_v_<t>` methods for each listed element type.
///
/// Every generated method returns the smallest element of `x` according to `partial_cmp`.
/// Pairs for which `partial_cmp` yields `None` (e.g. a NaN operand) are treated as equal, so
/// the result for such inputs depends on element order.
///
/// # Panics
///
/// The generated methods panic when `x` is empty.
#[macro_export]
macro_rules! impl_min_v {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                #[inline(always)]
                fn [<min_v_ $t>](&self, x: &[$t]) -> $t {
                    let smallest = x.iter().copied().min_by(|lhs, rhs| {
                        lhs.partial_cmp(rhs).unwrap_or(std::cmp::Ordering::Equal)
                    });
                    match smallest {
                        Some(value) => value,
                        // `min_by` yields `None` only for an empty slice.
                        None => panic!("Cannot find minimum of empty vector"),
                    }
                }
            }
        )+
    };
}
1787
/// Generates `max_vi_<t>` methods for each listed element type.
///
/// Every generated method returns `(value, index)` for the largest element of `x`, with the
/// index widened to `u64`. When several elements compare equal to the maximum, the lowest
/// index is reported. Pairs for which `partial_cmp` yields `None` (e.g. a NaN operand) are
/// treated as equal without any index tie-break.
///
/// # Panics
///
/// The generated methods panic when `x` is empty.
#[macro_export]
macro_rules! impl_max_vi {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                #[inline(always)]
                fn [<max_vi_ $t>](&self, x: &[$t]) -> ($t, u64) {
                    let winner = x.iter().enumerate().max_by(|lhs, rhs| {
                        lhs.1.partial_cmp(rhs.1).map_or(
                            std::cmp::Ordering::Equal,
                            // Equal values: rank the lower index higher so `max_by` keeps it.
                            |by_value| by_value.then_with(|| rhs.0.cmp(&lhs.0)),
                        )
                    });
                    match winner {
                        Some((position, &value)) => (value, position as u64),
                        // `max_by` yields `None` only for an empty slice.
                        None => panic!("Cannot find maximum of empty vector"),
                    }
                }
            }
        )+
    };
}
1815
/// Generates `min_vi_<t>` methods for each listed element type.
///
/// Every generated method returns `(value, index)` for the smallest element of `x`, with the
/// index widened to `u64`. When several elements compare equal to the minimum, the lowest
/// index is reported. Pairs for which `partial_cmp` yields `None` (e.g. a NaN operand) are
/// treated as equal without any index tie-break.
///
/// # Panics
///
/// The generated methods panic when `x` is empty.
#[macro_export]
macro_rules! impl_min_vi {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                #[inline(always)]
                fn [<min_vi_ $t>](&self, x: &[$t]) -> ($t, u64) {
                    let winner = x.iter().enumerate().min_by(|lhs, rhs| {
                        lhs.1.partial_cmp(rhs.1).map_or(
                            std::cmp::Ordering::Equal,
                            // Equal values: the lower index ranks lower, so `min_by` keeps it.
                            |by_value| by_value.then_with(|| lhs.0.cmp(&rhs.0)),
                        )
                    });
                    match winner {
                        Some((position, &value)) => (value, position as u64),
                        // `min_by` yields `None` only for an empty slice.
                        None => panic!("Cannot find minimum of empty vector"),
                    }
                }
            }
        )+
    };
}
1843
/// Generates `max_v_f16` and `max_v_bf16`.
///
/// Both methods return the largest element of `x` according to `partial_cmp`. Pairs for
/// which `partial_cmp` yields `None` (e.g. a NaN operand) are treated as equal, so the
/// result for such inputs depends on element order.
///
/// # Panics
///
/// The generated methods panic when `x` is empty.
#[macro_export]
macro_rules! impl_max_v_half {
    () => {
        #[inline(always)]
        fn max_v_f16(&self, x: &[half::f16]) -> half::f16 {
            let largest = x
                .iter()
                .copied()
                .max_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap_or(std::cmp::Ordering::Equal));
            match largest {
                Some(value) => value,
                // `max_by` yields `None` only for an empty slice.
                None => panic!("Cannot find maximum of empty vector"),
            }
        }

        #[inline(always)]
        fn max_v_bf16(&self, x: &[half::bf16]) -> half::bf16 {
            let largest = x
                .iter()
                .copied()
                .max_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap_or(std::cmp::Ordering::Equal));
            match largest {
                Some(value) => value,
                // `max_by` yields `None` only for an empty slice.
                None => panic!("Cannot find maximum of empty vector"),
            }
        }
    };
}
1869
/// Generates `min_v_f16` and `min_v_bf16`.
///
/// Both methods return the smallest element of `x` according to `partial_cmp`. Pairs for
/// which `partial_cmp` yields `None` (e.g. a NaN operand) are treated as equal, so the
/// result for such inputs depends on element order.
///
/// # Panics
///
/// The generated methods panic when `x` is empty.
#[macro_export]
macro_rules! impl_min_v_half {
    () => {
        #[inline(always)]
        fn min_v_f16(&self, x: &[half::f16]) -> half::f16 {
            let smallest = x
                .iter()
                .copied()
                .min_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap_or(std::cmp::Ordering::Equal));
            match smallest {
                Some(value) => value,
                // `min_by` yields `None` only for an empty slice.
                None => panic!("Cannot find minimum of empty vector"),
            }
        }

        #[inline(always)]
        fn min_v_bf16(&self, x: &[half::bf16]) -> half::bf16 {
            let smallest = x
                .iter()
                .copied()
                .min_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap_or(std::cmp::Ordering::Equal));
            match smallest {
                Some(value) => value,
                // `min_by` yields `None` only for an empty slice.
                None => panic!("Cannot find minimum of empty vector"),
            }
        }
    };
}
1895
/// Generates `max_vi_f16` and `max_vi_bf16`.
///
/// Both methods return `(value, index)` for the largest element of `x`, with the index
/// widened to `u64`. When several elements compare equal to the maximum, the lowest index is
/// reported. Pairs for which `partial_cmp` yields `None` (e.g. a NaN operand) are treated as
/// equal without any index tie-break.
///
/// # Panics
///
/// The generated methods panic when `x` is empty.
#[macro_export]
macro_rules! impl_max_vi_half {
    () => {
        #[inline(always)]
        fn max_vi_f16(&self, x: &[half::f16]) -> (half::f16, u64) {
            let winner = x.iter().enumerate().max_by(|lhs, rhs| {
                lhs.1.partial_cmp(rhs.1).map_or(
                    std::cmp::Ordering::Equal,
                    // Equal values: rank the lower index higher so `max_by` keeps it.
                    |by_value| by_value.then_with(|| rhs.0.cmp(&lhs.0)),
                )
            });
            match winner {
                Some((position, &value)) => (value, position as u64),
                // `max_by` yields `None` only for an empty slice.
                None => panic!("Cannot find maximum of empty vector"),
            }
        }

        #[inline(always)]
        fn max_vi_bf16(&self, x: &[half::bf16]) -> (half::bf16, u64) {
            let winner = x.iter().enumerate().max_by(|lhs, rhs| {
                lhs.1.partial_cmp(rhs.1).map_or(
                    std::cmp::Ordering::Equal,
                    // Equal values: rank the lower index higher so `max_by` keeps it.
                    |by_value| by_value.then_with(|| rhs.0.cmp(&lhs.0)),
                )
            });
            match winner {
                Some((position, &value)) => (value, position as u64),
                // `max_by` yields `None` only for an empty slice.
                None => panic!("Cannot find maximum of empty vector"),
            }
        }
    };
}
1939
/// Generates `min_vi_f16` and `min_vi_bf16`.
///
/// Both methods return `(value, index)` for the smallest element of `x`, with the index
/// widened to `u64`. When several elements compare equal to the minimum, the lowest index is
/// reported. Pairs for which `partial_cmp` yields `None` (e.g. a NaN operand) are treated as
/// equal without any index tie-break.
///
/// # Panics
///
/// The generated methods panic when `x` is empty.
#[macro_export]
macro_rules! impl_min_vi_half {
    () => {
        #[inline(always)]
        fn min_vi_f16(&self, x: &[half::f16]) -> (half::f16, u64) {
            let winner = x.iter().enumerate().min_by(|lhs, rhs| {
                lhs.1.partial_cmp(rhs.1).map_or(
                    std::cmp::Ordering::Equal,
                    // Equal values: the lower index ranks lower, so `min_by` keeps it.
                    |by_value| by_value.then_with(|| lhs.0.cmp(&rhs.0)),
                )
            });
            match winner {
                Some((position, &value)) => (value, position as u64),
                // `min_by` yields `None` only for an empty slice.
                None => panic!("Cannot find minimum of empty vector"),
            }
        }

        #[inline(always)]
        fn min_vi_bf16(&self, x: &[half::bf16]) -> (half::bf16, u64) {
            let winner = x.iter().enumerate().min_by(|lhs, rhs| {
                lhs.1.partial_cmp(rhs.1).map_or(
                    std::cmp::Ordering::Equal,
                    // Equal values: the lower index ranks lower, so `min_by` keeps it.
                    |by_value| by_value.then_with(|| lhs.0.cmp(&rhs.0)),
                )
            });
            match winner {
                Some((position, &value)) => (value, position as u64),
                // `min_by` yields `None` only for an empty slice.
                None => panic!("Cannot find minimum of empty vector"),
            }
        }
    };
}
1983
/// Generates `min_max_v_<t>` methods for each listed element type.
///
/// Every generated method returns `(min, max)` of `x`, each found with its own pass over the
/// slice using `partial_cmp`. Pairs for which `partial_cmp` yields `None` (e.g. a NaN
/// operand) are treated as equal, so the result for such inputs depends on element order.
///
/// # Panics
///
/// The generated methods panic when `x` is empty.
#[macro_export]
macro_rules! impl_min_max_v {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                #[inline(always)]
                fn [<min_max_v_ $t>](&self, x: &[$t]) -> ($t, $t) {
                    let smallest = x.iter().copied().min_by(|lhs, rhs| {
                        lhs.partial_cmp(rhs).unwrap_or(std::cmp::Ordering::Equal)
                    });
                    let largest = x.iter().copied().max_by(|lhs, rhs| {
                        lhs.partial_cmp(rhs).unwrap_or(std::cmp::Ordering::Equal)
                    });
                    match (smallest, largest) {
                        (Some(low), Some(high)) => (low, high),
                        // Both searches come back empty only for an empty slice.
                        _ => panic!("Cannot find min/max of empty vector"),
                    }
                }
            }
        )+
    };
}
2003
/// Generates `min_max_vi_<t>` methods for each listed element type.
///
/// Every generated method makes a single pass over `x` and returns
/// `((min, min_index), (max, max_index))`, with indices widened to `u64`. Both searches start
/// from `x[0]` and only move on a strict `<` / `>` comparison, so the lowest index wins among
/// equal extremes.
///
/// # Panics
///
/// The generated methods panic when `x` is empty.
#[macro_export]
macro_rules! impl_min_max_vi {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                #[inline(always)]
                fn [<min_max_vi_ $t>](&self, x: &[$t]) -> (($t, u64), ($t, u64)) {
                    let seed = match x.first() {
                        Some(&head) => (head, 0usize),
                        None => panic!("Cannot find min/max of empty vector"),
                    };
                    let (lowest, highest) = x.iter().enumerate().fold(
                        (seed, seed),
                        |(lowest, highest), (position, &candidate)| {
                            (
                                // Strict comparisons leave the earlier entry in place on ties.
                                if candidate < lowest.0 { (candidate, position) } else { lowest },
                                if candidate > highest.0 { (candidate, position) } else { highest },
                            )
                        },
                    );
                    ((lowest.0, lowest.1 as u64), (highest.0, highest.1 as u64))
                }
            }
        )+
    };
}
2037
/// Generates `min_max_i_<t>` methods for each listed element type.
///
/// Every generated method returns `(min_index, max_index)` for `x` as `u64` values, each
/// found with its own pass using `partial_cmp`. When several elements compare equal to an
/// extreme, the lowest index is reported. Pairs for which `partial_cmp` yields `None` (e.g.
/// a NaN operand) are treated as equal without any index tie-break.
///
/// # Panics
///
/// The generated methods panic when `x` is empty.
#[macro_export]
macro_rules! impl_min_max_i {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                #[inline(always)]
                fn [<min_max_i_ $t>](&self, x: &[$t]) -> (u64, u64) {
                    let lowest = x.iter().enumerate().min_by(|lhs, rhs| {
                        lhs.1.partial_cmp(rhs.1).map_or(
                            std::cmp::Ordering::Equal,
                            // Equal values: the lower index ranks lower, so `min_by` keeps it.
                            |by_value| by_value.then_with(|| lhs.0.cmp(&rhs.0)),
                        )
                    });
                    let highest = x.iter().enumerate().max_by(|lhs, rhs| {
                        lhs.1.partial_cmp(rhs.1).map_or(
                            std::cmp::Ordering::Equal,
                            // Equal values: rank the lower index higher so `max_by` keeps it.
                            |by_value| by_value.then_with(|| rhs.0.cmp(&lhs.0)),
                        )
                    });
                    match (lowest, highest) {
                        (Some((min_pos, _)), Some((max_pos, _))) => {
                            (min_pos as u64, max_pos as u64)
                        }
                        // Both searches come back empty only for an empty slice.
                        _ => panic!("Cannot find min/max indices of empty vector"),
                    }
                }
            }
        )+
    };
}
2077
/// Generates `min_max_v_f16` and `min_max_v_bf16`.
///
/// Both methods return `(min, max)` of `x`, each found with its own pass over the slice
/// using `partial_cmp`. Pairs for which `partial_cmp` yields `None` (e.g. a NaN operand) are
/// treated as equal, so the result for such inputs depends on element order.
///
/// # Panics
///
/// The generated methods panic when `x` is empty.
#[macro_export]
macro_rules! impl_min_max_v_half {
    () => {
        #[inline(always)]
        fn min_max_v_f16(&self, x: &[half::f16]) -> (half::f16, half::f16) {
            let smallest = x
                .iter()
                .copied()
                .min_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap_or(std::cmp::Ordering::Equal));
            let largest = x
                .iter()
                .copied()
                .max_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap_or(std::cmp::Ordering::Equal));
            match (smallest, largest) {
                (Some(low), Some(high)) => (low, high),
                // Both searches come back empty only for an empty slice.
                _ => panic!("Cannot find min/max of empty vector"),
            }
        }

        #[inline(always)]
        fn min_max_v_bf16(&self, x: &[half::bf16]) -> (half::bf16, half::bf16) {
            let smallest = x
                .iter()
                .copied()
                .min_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap_or(std::cmp::Ordering::Equal));
            let largest = x
                .iter()
                .copied()
                .max_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap_or(std::cmp::Ordering::Equal));
            match (smallest, largest) {
                (Some(low), Some(high)) => (low, high),
                // Both searches come back empty only for an empty slice.
                _ => panic!("Cannot find min/max of empty vector"),
            }
        }
    };
}
2115
/// Generates `min_max_vi_f16` and `min_max_vi_bf16`.
///
/// Both methods make a single pass over `x` and return
/// `((min, min_index), (max, max_index))`, with indices widened to `u64`. Both searches start
/// from `x[0]` and only move on a strict `<` / `>` comparison, so the lowest index wins among
/// equal extremes.
///
/// # Panics
///
/// The generated methods panic when `x` is empty.
#[macro_export]
macro_rules! impl_min_max_vi_half {
    () => {
        #[inline(always)]
        fn min_max_vi_f16(&self, x: &[half::f16]) -> ((half::f16, u64), (half::f16, u64)) {
            let seed = match x.first() {
                Some(&head) => (head, 0usize),
                None => panic!("Cannot find min/max of empty vector"),
            };
            let (lowest, highest) = x.iter().enumerate().fold(
                (seed, seed),
                |(lowest, highest), (position, &candidate)| {
                    (
                        // Strict comparisons leave the earlier entry in place on ties.
                        if candidate < lowest.0 { (candidate, position) } else { lowest },
                        if candidate > highest.0 { (candidate, position) } else { highest },
                    )
                },
            );
            ((lowest.0, lowest.1 as u64), (highest.0, highest.1 as u64))
        }

        #[inline(always)]
        fn min_max_vi_bf16(&self, x: &[half::bf16]) -> ((half::bf16, u64), (half::bf16, u64)) {
            let seed = match x.first() {
                Some(&head) => (head, 0usize),
                None => panic!("Cannot find min/max of empty vector"),
            };
            let (lowest, highest) = x.iter().enumerate().fold(
                (seed, seed),
                |(lowest, highest), (position, &candidate)| {
                    (
                        // Strict comparisons leave the earlier entry in place on ties.
                        if candidate < lowest.0 { (candidate, position) } else { lowest },
                        if candidate > highest.0 { (candidate, position) } else { highest },
                    )
                },
            );
            ((lowest.0, lowest.1 as u64), (highest.0, highest.1 as u64))
        }
    };
}
2169
/// Generates `min_max_i_f16` and `min_max_i_bf16`.
///
/// Both methods return `(min_index, max_index)` for `x` as `u64` values, each found with its
/// own pass using `partial_cmp`. When several elements compare equal to an extreme, the
/// lowest index is reported for both the minimum and the maximum, matching the generic
/// `min_max_i_*` methods and the half-precision `min_max_vi_*` / `max_vi_*` methods. Pairs
/// for which `partial_cmp` yields `None` (e.g. a NaN operand) are treated as equal without
/// any index tie-break.
///
/// # Panics
///
/// The generated methods panic when `x` is empty.
#[macro_export]
macro_rules! impl_min_max_i_half {
    () => {
        #[inline(always)]
        fn min_max_i_f16(&self, x: &[half::f16]) -> (u64, u64) {
            if x.is_empty() {
                panic!("Cannot find min/max indices of empty vector");
            }
            let min_idx = x
                .iter()
                .enumerate()
                .min_by(|a, b| match a.1.partial_cmp(b.1) {
                    // Equal values: the lower index ranks lower, so `min_by` keeps it.
                    Some(std::cmp::Ordering::Equal) => a.0.cmp(&b.0),
                    Some(ordering) => ordering,
                    None => std::cmp::Ordering::Equal,
                })
                .unwrap_or((0, &x[0]))
                .0 as u64;
            let max_idx = x
                .iter()
                .enumerate()
                .max_by(|a, b| match a.1.partial_cmp(b.1) {
                    // `max_by` keeps the LAST of equal elements, so reverse the index order
                    // on ties to make the first occurrence win.
                    Some(std::cmp::Ordering::Equal) => b.0.cmp(&a.0),
                    Some(ordering) => ordering,
                    None => std::cmp::Ordering::Equal,
                })
                .unwrap_or((0, &x[0]))
                .0 as u64;
            (min_idx, max_idx)
        }

        #[inline(always)]
        fn min_max_i_bf16(&self, x: &[half::bf16]) -> (u64, u64) {
            if x.is_empty() {
                panic!("Cannot find min/max indices of empty vector");
            }
            let min_idx = x
                .iter()
                .enumerate()
                .min_by(|a, b| match a.1.partial_cmp(b.1) {
                    // Equal values: the lower index ranks lower, so `min_by` keeps it.
                    Some(std::cmp::Ordering::Equal) => a.0.cmp(&b.0),
                    Some(ordering) => ordering,
                    None => std::cmp::Ordering::Equal,
                })
                .unwrap_or((0, &x[0]))
                .0 as u64;
            let max_idx = x
                .iter()
                .enumerate()
                .max_by(|a, b| match a.1.partial_cmp(b.1) {
                    // `max_by` keeps the LAST of equal elements, so reverse the index order
                    // on ties to make the first occurrence win.
                    Some(std::cmp::Ordering::Equal) => b.0.cmp(&a.0),
                    Some(ordering) => ordering,
                    None => std::cmp::Ordering::Equal,
                })
                .unwrap_or((0, &x[0]))
                .0 as u64;
            (min_idx, max_idx)
        }
    };
}
2215
/// Generates `min_i_<t>` methods for each listed element type.
///
/// Every generated method returns the index (as `u64`) of the smallest element of `x`
/// according to `partial_cmp`. When several elements compare equal to the minimum, the
/// lowest index is reported. Pairs for which `partial_cmp` yields `None` (e.g. a NaN
/// operand) are treated as equal without any index tie-break.
///
/// # Panics
///
/// The generated methods panic when `x` is empty.
#[macro_export]
macro_rules! impl_min_i {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                #[inline(always)]
                fn [<min_i_ $t>](&self, x: &[$t]) -> u64 {
                    let winner = x.iter().enumerate().min_by(|lhs, rhs| {
                        lhs.1.partial_cmp(rhs.1).map_or(
                            std::cmp::Ordering::Equal,
                            // Equal values: the lower index ranks lower, so `min_by` keeps it.
                            |by_value| by_value.then_with(|| lhs.0.cmp(&rhs.0)),
                        )
                    });
                    match winner {
                        Some((position, _)) => position as u64,
                        // `min_by` yields `None` only for an empty slice.
                        None => panic!("Cannot find minimum index of empty vector"),
                    }
                }
            }
        )+
    };
}
2243
/// Generates `max_i_<t>` methods for each listed element type.
///
/// Every generated method returns the index (as `u64`) of the largest element of `x`
/// according to `partial_cmp`. When several elements compare equal to the maximum, the
/// lowest index is reported. Pairs for which `partial_cmp` yields `None` (e.g. a NaN
/// operand) are treated as equal without any index tie-break.
///
/// # Panics
///
/// The generated methods panic when `x` is empty.
#[macro_export]
macro_rules! impl_max_i {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                #[inline(always)]
                fn [<max_i_ $t>](&self, x: &[$t]) -> u64 {
                    let winner = x.iter().enumerate().max_by(|lhs, rhs| {
                        lhs.1.partial_cmp(rhs.1).map_or(
                            std::cmp::Ordering::Equal,
                            // Equal values: rank the lower index higher so `max_by` keeps it.
                            |by_value| by_value.then_with(|| rhs.0.cmp(&lhs.0)),
                        )
                    });
                    match winner {
                        Some((position, _)) => position as u64,
                        // `max_by` yields `None` only for an empty slice.
                        None => panic!("Cannot find maximum index of empty vector"),
                    }
                }
            }
        )+
    };
}
2271
/// Generates `min_i_f16` and `min_i_bf16`.
///
/// Both methods return the index (as `u64`) of the smallest element of `x` according to
/// `partial_cmp`. Pairs for which `partial_cmp` yields `None` (e.g. a NaN operand) are
/// treated as equal. No explicit index tie-break is needed: `min_by` keeps the earlier of
/// two elements that compare equal.
///
/// # Panics
///
/// The generated methods panic when `x` is empty.
#[macro_export]
macro_rules! impl_min_i_half {
    () => {
        #[inline(always)]
        fn min_i_f16(&self, x: &[half::f16]) -> u64 {
            let winner = x.iter().enumerate().min_by(|lhs, rhs| {
                lhs.1.partial_cmp(rhs.1).unwrap_or(std::cmp::Ordering::Equal)
            });
            match winner {
                Some((position, _)) => position as u64,
                // `min_by` yields `None` only for an empty slice.
                None => panic!("Cannot find minimum index of empty vector"),
            }
        }

        #[inline(always)]
        fn min_i_bf16(&self, x: &[half::bf16]) -> u64 {
            let winner = x.iter().enumerate().min_by(|lhs, rhs| {
                lhs.1.partial_cmp(rhs.1).unwrap_or(std::cmp::Ordering::Equal)
            });
            match winner {
                Some((position, _)) => position as u64,
                // `min_by` yields `None` only for an empty slice.
                None => panic!("Cannot find minimum index of empty vector"),
            }
        }
    };
}
2301
/// Generates `max_i_f16` and `max_i_bf16`.
///
/// Both methods return the index (as `u64`) of the largest element of `x` according to
/// `partial_cmp`. When several elements compare equal to the maximum, the lowest index is
/// reported, matching the generic `max_i_*` methods and the half-precision `max_vi_*`
/// methods. Pairs for which `partial_cmp` yields `None` (e.g. a NaN operand) are treated as
/// equal without any index tie-break.
///
/// # Panics
///
/// The generated methods panic when `x` is empty.
#[macro_export]
macro_rules! impl_max_i_half {
    () => {
        #[inline(always)]
        fn max_i_f16(&self, x: &[half::f16]) -> u64 {
            if x.is_empty() {
                panic!("Cannot find maximum index of empty vector");
            }
            x.iter()
                .enumerate()
                .max_by(|a, b| match a.1.partial_cmp(b.1) {
                    // `max_by` keeps the LAST of equal elements, so reverse the index order
                    // on ties to make the first occurrence win.
                    Some(std::cmp::Ordering::Equal) => b.0.cmp(&a.0),
                    Some(ordering) => ordering,
                    None => std::cmp::Ordering::Equal,
                })
                .unwrap_or((0, &x[0]))
                .0 as u64
        }

        #[inline(always)]
        fn max_i_bf16(&self, x: &[half::bf16]) -> u64 {
            if x.is_empty() {
                panic!("Cannot find maximum index of empty vector");
            }
            x.iter()
                .enumerate()
                .max_by(|a, b| match a.1.partial_cmp(b.1) {
                    // `max_by` keeps the LAST of equal elements, so reverse the index order
                    // on ties to make the first occurrence win.
                    Some(std::cmp::Ordering::Equal) => b.0.cmp(&a.0),
                    Some(ordering) => ordering,
                    None => std::cmp::Ordering::Equal,
                })
                .unwrap_or((0, &x[0]))
                .0 as u64
        }
    };
}