// slsl/backend/macro.rs

1//! Macro definitions for backend operations
2//!
3//! This module contains macros that generate implementations for all supported
4//! data types and operations across different backends.
5
6// ========== BLAS Level 1 Operations ==========
7
/// Generate dot product implementations for integer types that return larger types to avoid overflow
#[macro_export]
macro_rules! impl_dot {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                /// Dot product of `a` and `b`, accumulated in `f64` so narrow
                /// integer element types cannot overflow during the reduction.
                ///
                /// Panics if the two slices differ in length.
                #[inline(always)]
                fn [<dot_ $t>](&self, a: &[$t], b: &[$t]) -> f64 {
                    assert_eq!(a.len(), b.len(), "Vector lengths must match for dot product");
                    a.iter()
                        .zip(b.iter())
                        .fold(0.0f64, |acc, (x, y)| acc + (*x as f64) * (*y as f64))
                }
            }
        )+
    };
}
23
/// Generate scale operations for all numeric types
#[macro_export]
macro_rules! impl_scal {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                /// Scales every element of `x` in place by the factor `a`.
                #[inline(always)]
                fn [<scal_ $t>](&self, a: $t, x: &mut [$t]) {
                    x.iter_mut().for_each(|elem| *elem *= a);
                }
            }
        )+
    };
}
40
/// Generate L1 norm (sum of absolute values) operations for signed numeric types
#[macro_export]
macro_rules! impl_asum_signed {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                /// Returns the L1 norm (sum of absolute values) of `x`.
                ///
                /// Summing an empty iterator already yields zero, so the
                /// previous explicit empty-slice branch was dead code and has
                /// been removed — behavior is unchanged.
                // NOTE(review): for signed integer types, `abs()` of the
                // minimum value (e.g. `i32::MIN`) overflows (panics in debug
                // builds), as can the sum itself; callers are expected to keep
                // magnitudes in range.
                #[inline(always)]
                fn [<asum_ $t>](&self, x: &[$t]) -> $t {
                    x.iter().map(|xi| (*xi).abs()).sum()
                }
            }
        )+
    };
}
58
/// Generate L1 norm (sum of absolute values) operations for unsigned numeric types
#[macro_export]
macro_rules! impl_asum_unsigned {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                /// Returns the L1 norm of `x`. Unsigned values are already
                /// non-negative, so this is a plain sum.
                ///
                /// Summing an empty iterator already yields zero, so the
                /// previous explicit empty-slice branch was dead code and has
                /// been removed — behavior is unchanged.
                #[inline(always)]
                fn [<asum_ $t>](&self, x: &[$t]) -> $t {
                    x.iter().sum()
                }
            }
        )+
    };
}
77
/// Generate L1 norm (sum of absolute values) operations for half precision types
#[macro_export]
macro_rules! impl_asum_half {
    () => {
        /// L1 norm of an f16 slice, widened to f32 per element and accumulated
        /// in f32. Summing an empty iterator yields 0.0, so the previous
        /// explicit empty-slice branch was redundant and has been removed.
        #[inline(always)]
        fn asum_f16(&self, x: &[half::f16]) -> f32 {
            x.iter().map(|xi| xi.to_f32().abs()).sum()
        }

        /// L1 norm of a bf16 slice, widened to f32 per element and accumulated
        /// in f32.
        #[inline(always)]
        fn asum_bf16(&self, x: &[half::bf16]) -> f32 {
            x.iter().map(|xi| xi.to_f32().abs()).sum()
        }
    };
}
99
100// ========== BLAS Level 3 Operations ==========
101
/// Generate optimized GEMM implementations for f32 and f64 using the unified
/// `gemm` library. This macro only expands for the literal `(f32, f64)` pair.
#[macro_export]
macro_rules! impl_gemm_sd {
    (f32, f64) => {
        /// Performs matrix multiplication for f32: C = A * B
        /// Uses optimized gemm library implementation with dynamic thread configuration
        #[inline(always)]
        /// # Safety
        ///
        /// The caller must ensure that `a`, `b`, and `c` are valid pointers to arrays of the
        /// correct size, and that `m`, `n`, `k`, `lda`, `ldb`, and `ldc` are valid dimensions
        /// and leading dimensions for the matrices involved in the multiplication.
        /// The matrices must not overlap in a way that would cause data races if accessed
        /// concurrently.
        unsafe fn gemm_f32(
            &self,
            m: usize,
            n: usize,
            k: usize,
            a: *const f32,
            lda: usize,
            b: *const f32,
            ldb: usize,
            c: *mut f32,
            ldc: usize,
        ) {
            // Pick the parallelism strategy: rayon with the crate-configured
            // thread count when available and useful, otherwise single-threaded.
            #[cfg(feature = "rayon")]
            let parallelism = {
                let num_threads = $crate::get_num_threads();
                if num_threads > 1 {
                    gemm::Parallelism::Rayon(num_threads)
                } else {
                    gemm::Parallelism::None
                }
            };
            #[cfg(not(feature = "rayon"))]
            let parallelism = gemm::Parallelism::None;

            // Use unified gemm library for optimized matrix multiplication
            // Formula: dst := alpha×dst + beta×lhs×rhs
            // For C = A * B, we want: C = 0*C + 1*A*B, so alpha=0, beta=1
            // Note: gemm expects (dst, lhs, rhs) parameter order, so for C = A*B we pass (C, A, B)
            // NOTE(review): `gemm::gemm` takes each matrix's strides as
            // (col_stride, row_stride), i.e. `*_cs` BEFORE `*_rs` — verify
            // against the gemm crate signature. For row-major data the column
            // stride is 1 and the row stride is the leading dimension, which is
            // exactly what is passed below; earlier inline labels had the two
            // names swapped, but the values themselves are correct.
            gemm::gemm(
                m,
                n,
                k,
                c,
                1,            // dst_cs: column stride for C (row-major contiguous: 1)
                ldc as isize, // dst_rs: row stride for C (elements between rows: ldc)
                false,        // read_dst: don't read existing values in C
                a,            // lhs: A matrix (first operand)
                1,            // lhs_cs: column stride for A (row-major contiguous: 1)
                lda as isize, // lhs_rs: row stride for A (elements between rows: lda)
                b,            // rhs: B matrix (second operand)
                1,            // rhs_cs: column stride for B (row-major contiguous: 1)
                ldb as isize, // rhs_rs: row stride for B (elements between rows: ldb)
                0.0f32,       // alpha: coefficient for existing C (0 since we overwrite)
                1.0f32,       // beta: coefficient for A*B
                false,
                false,
                false, // conj_dst, conj_lhs, conj_rhs
                parallelism,
            );
        }

        /// Performs matrix multiplication for f64: C = A * B
        /// Uses optimized gemm library implementation with dynamic thread configuration
        #[inline(always)]
        /// # Safety
        ///
        /// The caller must ensure that `a`, `b`, and `c` are valid pointers to arrays of the
        /// correct size, and that `m`, `n`, `k`, `lda`, `ldb`, and `ldc` are valid dimensions
        /// and leading dimensions for the matrices involved in the multiplication.
        /// The matrices must not overlap in a way that would cause data races if accessed
        /// concurrently.
        unsafe fn gemm_f64(
            &self,
            m: usize,
            n: usize,
            k: usize,
            a: *const f64,
            lda: usize,
            b: *const f64,
            ldb: usize,
            c: *mut f64,
            ldc: usize,
        ) {
            // Same parallelism selection as gemm_f32.
            #[cfg(feature = "rayon")]
            let parallelism = {
                let num_threads = $crate::get_num_threads();
                if num_threads > 1 {
                    gemm::Parallelism::Rayon(num_threads)
                } else {
                    gemm::Parallelism::None
                }
            };
            #[cfg(not(feature = "rayon"))]
            let parallelism = gemm::Parallelism::None;

            // Use unified gemm library for optimized matrix multiplication
            // Formula: dst := alpha×dst + beta×lhs×rhs
            // For C = A * B, we want: C = 0*C + 1*A*B, so alpha=0, beta=1
            // Note: gemm expects (dst, lhs, rhs) parameter order, so for C = A*B we pass (C, A, B)
            // NOTE(review): stride arguments are (col_stride, row_stride) — see
            // the matching note in gemm_f32 above for the full rationale.
            gemm::gemm(
                m,
                n,
                k,
                c,
                1,            // dst_cs: column stride for C (row-major contiguous: 1)
                ldc as isize, // dst_rs: row stride for C (elements between rows: ldc)
                false,        // read_dst: don't read existing values in C
                a,            // lhs: A matrix (first operand)
                1,            // lhs_cs: column stride for A (row-major contiguous: 1)
                lda as isize, // lhs_rs: row stride for A (elements between rows: lda)
                b,            // rhs: B matrix (second operand)
                1,            // rhs_cs: column stride for B (row-major contiguous: 1)
                ldb as isize, // rhs_rs: row stride for B (elements between rows: ldb)
                0.0f64,       // alpha: coefficient for existing C (0 since we overwrite)
                1.0f64,       // beta: coefficient for A*B
                false,
                false,
                false, // conj_dst, conj_lhs, conj_rhs
                parallelism,
            );
        }
    };
}
234
/// Generate GEMM implementations for other numeric types (fallback to simple implementation)
#[macro_export]
macro_rules! impl_gemm {
    ($($t:ty),*) => {
        $(
            paste::paste! {
                /// Performs matrix multiplication C = A * B
                ///
                /// Existing contents of `c` are overwritten (not accumulated),
                /// matching the optimized implementations that pass
                /// `read_dst = false`. (The doc previously claimed an
                /// alpha/beta form the code never implemented.)
                ///
                /// # Safety
                ///
                /// This function is unsafe because it performs raw pointer operations.
                /// The caller must ensure:
                /// - All pointers are valid and point to properly allocated memory
                /// - Matrix dimensions are correct (m, n, k)
                /// - Leading dimensions (lda, ldb, ldc) are valid
                /// - Memory regions do not overlap inappropriately
                #[inline(always)]
                unsafe fn [<gemm_ $t>](
                    &self,
                    m: usize,
                    n: usize,
                    k: usize,
                    a: *const $t,
                    lda: usize,
                    b: *const $t,
                    ldb: usize,
                    c: *mut $t,
                    ldc: usize,
                ) {
                    // Simple row-major triple loop for non-optimized types.
                    for i in 0..m {
                        for j in 0..n {
                            let mut sum = 0 as $t;
                            for l in 0..k {
                                sum += *a.add(i * lda + l) * *b.add(l * ldb + j);
                            }
                            *c.add(i * ldc + j) = sum;
                        }
                    }
                }
            }
        )*
    };
}
279
/// Generate general matrix multiplication for half precision types using unified gemm library
#[macro_export]
macro_rules! impl_gemm_half {
    () => {
        /// Performs matrix multiplication C = A * B for f16 matrices
        /// Uses optimized gemm library implementation with dynamic thread configuration
        ///
        /// # Safety
        ///
        /// This function is unsafe because it performs raw pointer operations.
        /// The caller must ensure:
        /// - All pointers are valid and point to properly allocated memory
        /// - Matrix dimensions are correct (m, n, k)
        /// - Leading dimensions (lda, ldb, ldc) are valid
        /// - Memory regions do not overlap inappropriately
        #[inline(always)]
        unsafe fn gemm_f16(
            &self,
            m: usize,
            n: usize,
            k: usize,
            a: *const half::f16,
            lda: usize,
            b: *const half::f16,
            ldb: usize,
            c: *mut half::f16,
            ldc: usize,
        ) {
            // Pick the parallelism strategy: rayon with the crate-configured
            // thread count when available and useful, otherwise single-threaded.
            #[cfg(feature = "rayon")]
            let parallelism = {
                let num_threads = $crate::get_num_threads();
                if num_threads > 1 {
                    gemm::Parallelism::Rayon(num_threads)
                } else {
                    gemm::Parallelism::None
                }
            };
            #[cfg(not(feature = "rayon"))]
            let parallelism = gemm::Parallelism::None;

            // Formula: dst := alpha×dst + beta×lhs×rhs
            // For C = A * B we want C = 0*C + 1*A*B, so alpha = 0, beta = 1.
            // Note: gemm expects (dst, lhs, rhs) parameter order, so we pass (C, A, B).
            // NOTE(review): `gemm::gemm` takes strides as (col_stride, row_stride),
            // i.e. `*_cs` before `*_rs`; for row-major data the column stride is 1
            // and the row stride is the leading dimension — verify against the
            // gemm crate signature.
            gemm::gemm(
                m,
                n,
                k,
                c,
                1,               // dst_cs: column stride for C (row-major contiguous: 1)
                ldc as isize,    // dst_rs: row stride for C (elements between rows: ldc)
                false,           // read_dst: don't read existing values in C
                a,               // lhs: A matrix (first operand)
                1,               // lhs_cs: column stride for A (row-major contiguous: 1)
                lda as isize,    // lhs_rs: row stride for A (elements between rows: lda)
                b,               // rhs: B matrix (second operand)
                1,               // rhs_cs: column stride for B (row-major contiguous: 1)
                ldb as isize,    // rhs_rs: row stride for B (elements between rows: ldb)
                gemm::f16::ZERO, // alpha: coefficient for existing C (0 since we overwrite)
                gemm::f16::ONE,  // beta: coefficient for A*B
                false,
                false,
                false, // conj_dst, conj_lhs, conj_rhs
                parallelism,
            );
        }

        /// Performs matrix multiplication C = A * B for bf16 matrices
        ///
        /// # Safety
        ///
        /// This function is unsafe because it performs raw pointer operations.
        /// The caller must ensure:
        /// - All pointers are valid and point to properly allocated memory
        /// - Matrix dimensions are correct (m, n, k)
        /// - Leading dimensions (lda, ldb, ldc) are valid
        /// - Memory regions do not overlap inappropriately
        #[inline(always)]
        unsafe fn gemm_bf16(
            &self,
            m: usize,
            n: usize,
            k: usize,
            a: *const half::bf16,
            lda: usize,
            b: *const half::bf16,
            ldb: usize,
            c: *mut half::bf16,
            ldc: usize,
        ) {
            // Naive triple loop; accumulation stays in bf16, matching the
            // original behavior (no widening to f32).
            for i in 0..m {
                for j in 0..n {
                    let mut sum = half::bf16::ZERO;
                    for l in 0..k {
                        let a_val = *a.add(i * lda + l);
                        let b_val = *b.add(l * ldb + j);
                        sum += a_val * b_val;
                    }
                    *c.add(i * ldc + j) = sum;
                }
            }
        }
    };
}
392
393// ========== Vectorized Math Functions ==========
394
/// Generate vectorized exponential function for floating point types
/// (`exp()` only exists on float primitives, so this cannot expand for integers —
/// the previous doc claimed "all numeric types")
#[macro_export]
macro_rules! impl_v_exp {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                /// Computes `out[i] = exp(x[i])` element-wise.
                ///
                /// Panics if `x` and `out` differ in length.
                #[inline(always)]
                fn [<v_exp_ $t>](&self, x: &[$t], out: &mut [$t]) {
                    assert_eq!(
                        x.len(),
                        out.len(),
                        "Input and output slices must have same length"
                    );
                    for (o, xi) in out.iter_mut().zip(x.iter()) {
                        *o = xi.exp();
                    }
                }
            }
        )+
    };
}
416
/// Generate vectorized exponential function for f16 and bf16
#[macro_export]
macro_rules! impl_v_exp_half {
    () => {
        /// Element-wise exponential for f16, computed by widening each value
        /// to f32 and narrowing the result back.
        #[inline(always)]
        fn v_exp_f16(&self, x: &[half::f16], out: &mut [half::f16]) {
            assert_eq!(
                x.len(),
                out.len(),
                "Input and output slices must have same length"
            );
            x.iter()
                .zip(out.iter_mut())
                .for_each(|(xi, o)| *o = half::f16::from_f32(xi.to_f32().exp()));
        }

        /// Element-wise exponential for bf16, computed through f32.
        #[inline(always)]
        fn v_exp_bf16(&self, x: &[half::bf16], out: &mut [half::bf16]) {
            assert_eq!(
                x.len(),
                out.len(),
                "Input and output slices must have same length"
            );
            x.iter()
                .zip(out.iter_mut())
                .for_each(|(xi, o)| *o = half::bf16::from_f32(xi.to_f32().exp()));
        }
    };
}
446
/// Generate vectorized sine function for floating point types
/// (`sin()` only exists on float primitives — the previous doc claimed "all numeric types")
#[macro_export]
macro_rules! impl_v_sin {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                /// Computes `out[i] = sin(x[i])` element-wise.
                ///
                /// Panics if `x` and `out` differ in length.
                #[inline(always)]
                fn [<v_sin_ $t>](&self, x: &[$t], out: &mut [$t]) {
                    assert_eq!(
                        x.len(),
                        out.len(),
                        "Input and output slices must have same length"
                    );
                    for (o, xi) in out.iter_mut().zip(x.iter()) {
                        *o = xi.sin();
                    }
                }
            }
        )+
    };
}
468
/// Generate vectorized sine for f16 and bf16
#[macro_export]
macro_rules! impl_v_sin_half {
    () => {
        /// Element-wise sine for f16, computed by widening each value to f32
        /// and narrowing the result back.
        #[inline(always)]
        fn v_sin_f16(&self, x: &[half::f16], out: &mut [half::f16]) {
            assert_eq!(
                x.len(),
                out.len(),
                "Input and output slices must have same length"
            );
            x.iter()
                .zip(out.iter_mut())
                .for_each(|(xi, o)| *o = half::f16::from_f32(xi.to_f32().sin()));
        }

        /// Element-wise sine for bf16, computed through f32.
        #[inline(always)]
        fn v_sin_bf16(&self, x: &[half::bf16], out: &mut [half::bf16]) {
            assert_eq!(
                x.len(),
                out.len(),
                "Input and output slices must have same length"
            );
            x.iter()
                .zip(out.iter_mut())
                .for_each(|(xi, o)| *o = half::bf16::from_f32(xi.to_f32().sin()));
        }
    };
}
498
/// Generate vectorized cosine function for floating point types
/// (`cos()` only exists on float primitives — the previous doc claimed "all numeric types")
#[macro_export]
macro_rules! impl_v_cos {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                /// Computes `out[i] = cos(x[i])` element-wise.
                ///
                /// Panics if `x` and `out` differ in length.
                #[inline(always)]
                fn [<v_cos_ $t>](&self, x: &[$t], out: &mut [$t]) {
                    assert_eq!(
                        x.len(),
                        out.len(),
                        "Input and output slices must have same length"
                    );
                    for (o, xi) in out.iter_mut().zip(x.iter()) {
                        *o = xi.cos();
                    }
                }
            }
        )+
    };
}
520
/// Generate vectorized cosine function for f16 and bf16
#[macro_export]
macro_rules! impl_v_cos_half {
    () => {
        /// Element-wise cosine for f16, computed by widening each value to f32
        /// and narrowing the result back.
        #[inline(always)]
        fn v_cos_f16(&self, x: &[half::f16], out: &mut [half::f16]) {
            assert_eq!(
                x.len(),
                out.len(),
                "Input and output slices must have same length"
            );
            x.iter()
                .zip(out.iter_mut())
                .for_each(|(xi, o)| *o = half::f16::from_f32(xi.to_f32().cos()));
        }

        /// Element-wise cosine for bf16, computed through f32.
        #[inline(always)]
        fn v_cos_bf16(&self, x: &[half::bf16], out: &mut [half::bf16]) {
            assert_eq!(
                x.len(),
                out.len(),
                "Input and output slices must have same length"
            );
            x.iter()
                .zip(out.iter_mut())
                .for_each(|(xi, o)| *o = half::bf16::from_f32(xi.to_f32().cos()));
        }
    };
}
550
/// Generate vectorized hyperbolic tangent for floating point types
/// (`tanh()` only exists on float primitives — the previous doc claimed "all numeric types")
#[macro_export]
macro_rules! impl_v_tanh {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                /// Computes `out[i] = tanh(x[i])` element-wise.
                ///
                /// Panics if `x` and `out` differ in length.
                #[inline(always)]
                fn [<v_tanh_ $t>](&self, x: &[$t], out: &mut [$t]) {
                    assert_eq!(
                        x.len(),
                        out.len(),
                        "Input and output slices must have same length"
                    );
                    for (o, xi) in out.iter_mut().zip(x.iter()) {
                        *o = xi.tanh();
                    }
                }
            }
        )+
    };
}
572
/// Generate vectorized hyperbolic tangent for f16 and bf16
#[macro_export]
macro_rules! impl_v_tanh_half {
    () => {
        /// Element-wise tanh for f16, computed through f32.
        // `#[inline(always)]` added for consistency: every other half-precision
        // vector math macro in this module marks its generated fns this way.
        #[inline(always)]
        fn v_tanh_f16(&self, x: &[half::f16], out: &mut [half::f16]) {
            assert_eq!(
                x.len(),
                out.len(),
                "Input and output slices must have same length"
            );
            for (o, xi) in out.iter_mut().zip(x.iter()) {
                *o = half::f16::from_f32(xi.to_f32().tanh());
            }
        }

        /// Element-wise tanh for bf16, computed through f32.
        #[inline(always)]
        fn v_tanh_bf16(&self, x: &[half::bf16], out: &mut [half::bf16]) {
            assert_eq!(
                x.len(),
                out.len(),
                "Input and output slices must have same length"
            );
            for (o, xi) in out.iter_mut().zip(x.iter()) {
                *o = half::bf16::from_f32(xi.to_f32().tanh());
            }
        }
    };
}
600
/// Generate vectorized natural logarithm for floating point types
/// (`ln()` only exists on float primitives — the previous doc claimed "all numeric types")
#[macro_export]
macro_rules! impl_v_log {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                /// Computes `out[i] = ln(x[i])` element-wise.
                ///
                /// Panics if `x` and `out` differ in length.
                #[inline(always)]
                fn [<v_log_ $t>](&self, x: &[$t], out: &mut [$t]) {
                    assert_eq!(
                        x.len(),
                        out.len(),
                        "Input and output slices must have same length"
                    );
                    for (o, xi) in out.iter_mut().zip(x.iter()) {
                        *o = xi.ln();
                    }
                }
            }
        )+
    };
}
622
/// Generate vectorized natural logarithm for f16 and bf16
#[macro_export]
macro_rules! impl_v_log_half {
    () => {
        /// Element-wise natural log for f16, computed by widening each value
        /// to f32 and narrowing the result back.
        #[inline(always)]
        fn v_log_f16(&self, x: &[half::f16], out: &mut [half::f16]) {
            assert_eq!(
                x.len(),
                out.len(),
                "Input and output slices must have same length"
            );
            x.iter()
                .zip(out.iter_mut())
                .for_each(|(xi, o)| *o = half::f16::from_f32(xi.to_f32().ln()));
        }

        /// Element-wise natural log for bf16, computed through f32.
        #[inline(always)]
        fn v_log_bf16(&self, x: &[half::bf16], out: &mut [half::bf16]) {
            assert_eq!(
                x.len(),
                out.len(),
                "Input and output slices must have same length"
            );
            x.iter()
                .zip(out.iter_mut())
                .for_each(|(xi, o)| *o = half::bf16::from_f32(xi.to_f32().ln()));
        }
    };
}
652
/// Generate vectorized square root for floating point types
/// (`sqrt()` only exists on float primitives — the previous doc claimed "all numeric types")
#[macro_export]
macro_rules! impl_v_sqrt {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                /// Computes `out[i] = sqrt(x[i])` element-wise.
                ///
                /// Panics if `x` and `out` differ in length.
                #[inline(always)]
                fn [<v_sqrt_ $t>](&self, x: &[$t], out: &mut [$t]) {
                    assert_eq!(
                        x.len(),
                        out.len(),
                        "Input and output slices must have same length"
                    );
                    for (o, xi) in out.iter_mut().zip(x.iter()) {
                        *o = xi.sqrt();
                    }
                }
            }
        )+
    };
}
674
/// Generate vectorized square root for f16 and bf16
#[macro_export]
macro_rules! impl_v_sqrt_half {
    () => {
        /// Element-wise square root for f16, computed by widening each value
        /// to f32 and narrowing the result back.
        #[inline(always)]
        fn v_sqrt_f16(&self, x: &[half::f16], out: &mut [half::f16]) {
            assert_eq!(
                x.len(),
                out.len(),
                "Input and output slices must have same length"
            );
            x.iter()
                .zip(out.iter_mut())
                .for_each(|(xi, o)| *o = half::f16::from_f32(xi.to_f32().sqrt()));
        }

        /// Element-wise square root for bf16, computed through f32.
        #[inline(always)]
        fn v_sqrt_bf16(&self, x: &[half::bf16], out: &mut [half::bf16]) {
            assert_eq!(
                x.len(),
                out.len(),
                "Input and output slices must have same length"
            );
            x.iter()
                .zip(out.iter_mut())
                .for_each(|(xi, o)| *o = half::bf16::from_f32(xi.to_f32().sqrt()));
        }
    };
}
704
/// Generate vectorized square (element-wise) for all numeric types
#[macro_export]
macro_rules! impl_v_sqr {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                /// Computes `out[i] = x[i] * x[i]` element-wise.
                ///
                /// Panics if `x` and `out` differ in length.
                #[inline(always)]
                fn [<v_sqr_ $t>](&self, x: &[$t], out: &mut [$t]) {
                    assert_eq!(
                        x.len(),
                        out.len(),
                        "Input and output slices must have same length"
                    );
                    x.iter().zip(out.iter_mut()).for_each(|(xi, o)| {
                        let v = *xi;
                        *o = v * v;
                    });
                }
            }
        )+
    };
}
726
/// Generate vectorized square (element-wise) for f16 and bf16
#[macro_export]
macro_rules! impl_v_sqr_half {
    () => {
        /// Computes `out[i] = x[i] * x[i]` for f16 slices (native f16 multiply,
        /// no widening).
        #[inline(always)]
        fn v_sqr_f16(&self, x: &[half::f16], out: &mut [half::f16]) {
            assert_eq!(
                x.len(),
                out.len(),
                "Input and output slices must have same length"
            );
            x.iter().zip(out.iter_mut()).for_each(|(xi, o)| {
                let v = *xi;
                *o = v * v;
            });
        }

        /// Computes `out[i] = x[i] * x[i]` for bf16 slices.
        #[inline(always)]
        fn v_sqr_bf16(&self, x: &[half::bf16], out: &mut [half::bf16]) {
            assert_eq!(
                x.len(),
                out.len(),
                "Input and output slices must have same length"
            );
            x.iter().zip(out.iter_mut()).for_each(|(xi, o)| {
                let v = *xi;
                *o = v * v;
            });
        }
    };
}
756
/// Generate vectorized element-wise addition for all numeric types
#[macro_export]
macro_rules! impl_v_add {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                /// Computes `out[i] = a[i] + b[i]` element-wise.
                ///
                /// Panics if any slice length differs.
                #[inline(always)]
                fn [<v_add_ $t>](&self, a: &[$t], b: &[$t], out: &mut [$t]) {
                    assert_eq!(a.len(), b.len(), "Input slices must have same length");
                    assert_eq!(
                        a.len(),
                        out.len(),
                        "Input and output slices must have same length"
                    );
                    a.iter()
                        .zip(b.iter())
                        .zip(out.iter_mut())
                        .for_each(|((ai, bi), o)| *o = *ai + *bi);
                }
            }
        )+
    };
}
779
/// Generate vectorized element-wise addition for half precision types
#[macro_export]
macro_rules! impl_v_add_half {
    () => {
        /// Computes `out[i] = a[i] + b[i]` for f16 slices (native f16 add).
        #[inline(always)]
        fn v_add_f16(&self, a: &[half::f16], b: &[half::f16], out: &mut [half::f16]) {
            assert_eq!(a.len(), b.len(), "Input slices must have same length");
            assert_eq!(
                a.len(),
                out.len(),
                "Input and output slices must have same length"
            );
            a.iter()
                .zip(b.iter())
                .zip(out.iter_mut())
                .for_each(|((ai, bi), o)| *o = *ai + *bi);
        }

        /// Computes `out[i] = a[i] + b[i]` for bf16 slices.
        #[inline(always)]
        fn v_add_bf16(&self, a: &[half::bf16], b: &[half::bf16], out: &mut [half::bf16]) {
            assert_eq!(a.len(), b.len(), "Input slices must have same length");
            assert_eq!(
                a.len(),
                out.len(),
                "Input and output slices must have same length"
            );
            a.iter()
                .zip(b.iter())
                .zip(out.iter_mut())
                .for_each(|((ai, bi), o)| *o = *ai + *bi);
        }
    };
}
811
/// Generate vectorized element-wise subtraction for all numeric types
#[macro_export]
macro_rules! impl_v_sub {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                /// Computes `out[i] = a[i] - b[i]` element-wise.
                ///
                /// Panics if any slice length differs.
                #[inline(always)]
                fn [<v_sub_ $t>](&self, a: &[$t], b: &[$t], out: &mut [$t]) {
                    assert_eq!(a.len(), b.len(), "Input slices must have same length");
                    assert_eq!(
                        a.len(),
                        out.len(),
                        "Input and output slices must have same length"
                    );
                    a.iter()
                        .zip(b.iter())
                        .zip(out.iter_mut())
                        .for_each(|((ai, bi), o)| *o = *ai - *bi);
                }
            }
        )+
    };
}
834
/// Generate vectorized element-wise subtraction for half precision types
#[macro_export]
macro_rules! impl_v_sub_half {
    () => {
        /// Computes `out[i] = a[i] - b[i]` for f16 slices (native f16 subtract).
        #[inline(always)]
        fn v_sub_f16(&self, a: &[half::f16], b: &[half::f16], out: &mut [half::f16]) {
            assert_eq!(a.len(), b.len(), "Input slices must have same length");
            assert_eq!(
                a.len(),
                out.len(),
                "Input and output slices must have same length"
            );
            a.iter()
                .zip(b.iter())
                .zip(out.iter_mut())
                .for_each(|((ai, bi), o)| *o = *ai - *bi);
        }

        /// Computes `out[i] = a[i] - b[i]` for bf16 slices.
        #[inline(always)]
        fn v_sub_bf16(&self, a: &[half::bf16], b: &[half::bf16], out: &mut [half::bf16]) {
            assert_eq!(a.len(), b.len(), "Input slices must have same length");
            assert_eq!(
                a.len(),
                out.len(),
                "Input and output slices must have same length"
            );
            a.iter()
                .zip(b.iter())
                .zip(out.iter_mut())
                .for_each(|((ai, bi), o)| *o = *ai - *bi);
        }
    };
}
866
/// Generate vectorized element-wise multiplication for all numeric types.
///
/// For each listed type `$t`, expands to `v_mul_<t>(&self, a, b, out)`
/// storing `a[i] * b[i]` into `out[i]`. Panics on mismatched lengths.
#[macro_export]
macro_rules! impl_v_mul {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                #[inline(always)]
                fn [<v_mul_ $t>](&self, a: &[$t], b: &[$t], out: &mut [$t]) {
                    assert_eq!(a.len(), b.len(), "Input slices must have same length");
                    assert_eq!(a.len(), out.len(), "Input and output slices must have same length");
                    for (dst, (lhs, rhs)) in out.iter_mut().zip(a.iter().zip(b.iter())) {
                        *dst = *lhs * *rhs;
                    }
                }
            }
        )+
    };
}
889
/// Generate vectorized element-wise multiplication for half precision types.
///
/// Expands to `v_mul_f16` and `v_mul_bf16`, each writing `a[i] * b[i]`
/// into `out[i]`. Panics if the slice lengths do not all match.
#[macro_export]
macro_rules! impl_v_mul_half {
    () => {
        #[inline(always)]
        fn v_mul_f16(&self, a: &[half::f16], b: &[half::f16], out: &mut [half::f16]) {
            assert_eq!(a.len(), b.len(), "Input slices must have same length");
            assert_eq!(a.len(), out.len(), "Input and output slices must have same length");
            for (dst, (lhs, rhs)) in out.iter_mut().zip(a.iter().zip(b.iter())) {
                *dst = *lhs * *rhs;
            }
        }

        #[inline(always)]
        fn v_mul_bf16(&self, a: &[half::bf16], b: &[half::bf16], out: &mut [half::bf16]) {
            assert_eq!(a.len(), b.len(), "Input slices must have same length");
            assert_eq!(a.len(), out.len(), "Input and output slices must have same length");
            for (dst, (lhs, rhs)) in out.iter_mut().zip(a.iter().zip(b.iter())) {
                *dst = *lhs * *rhs;
            }
        }
    };
}
921
/// Generate vectorized element-wise division for all numeric types.
///
/// For each listed type `$t`, expands to `v_div_<t>(&self, a, b, out)`
/// storing `a[i] / b[i]` into `out[i]`. Panics on mismatched lengths.
/// For integer instantiations a zero divisor panics, per Rust semantics.
#[macro_export]
macro_rules! impl_v_div {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                #[inline(always)]
                fn [<v_div_ $t>](&self, a: &[$t], b: &[$t], out: &mut [$t]) {
                    assert_eq!(a.len(), b.len(), "Input slices must have same length");
                    assert_eq!(a.len(), out.len(), "Input and output slices must have same length");
                    for (dst, (lhs, rhs)) in out.iter_mut().zip(a.iter().zip(b.iter())) {
                        *dst = *lhs / *rhs;
                    }
                }
            }
        )+
    };
}
944
/// Generate vectorized element-wise division for half precision types.
///
/// Expands to `v_div_f16` and `v_div_bf16`, each writing `a[i] / b[i]`
/// into `out[i]`. Panics if the slice lengths do not all match.
#[macro_export]
macro_rules! impl_v_div_half {
    () => {
        #[inline(always)]
        fn v_div_f16(&self, a: &[half::f16], b: &[half::f16], out: &mut [half::f16]) {
            assert_eq!(a.len(), b.len(), "Input slices must have same length");
            assert_eq!(a.len(), out.len(), "Input and output slices must have same length");
            for (dst, (lhs, rhs)) in out.iter_mut().zip(a.iter().zip(b.iter())) {
                *dst = *lhs / *rhs;
            }
        }

        #[inline(always)]
        fn v_div_bf16(&self, a: &[half::bf16], b: &[half::bf16], out: &mut [half::bf16]) {
            assert_eq!(a.len(), b.len(), "Input slices must have same length");
            assert_eq!(a.len(), out.len(), "Input and output slices must have same length");
            for (dst, (lhs, rhs)) in out.iter_mut().zip(a.iter().zip(b.iter())) {
                *dst = *lhs / *rhs;
            }
        }
    };
}
976
/// Generate scalar division operations for all numeric types.
///
/// For each listed type `$t`, expands to `v_div_scalar_<t>(&self, x, scalar, out)`
/// storing `x[i] / scalar` into `out[i]`. Panics on mismatched lengths.
#[macro_export]
macro_rules! impl_v_div_scalar {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                #[inline(always)]
                fn [<v_div_scalar_ $t>](&self, x: &[$t], scalar: $t, out: &mut [$t]) {
                    assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
                    for (dst, src) in out.iter_mut().zip(x.iter()) {
                        *dst = *src / scalar;
                    }
                }
            }
        )+
    };
}
998
/// Generate scalar division operations for half precision types.
///
/// Expands to `v_div_scalar_f16` and `v_div_scalar_bf16`, each writing
/// `x[i] / scalar` into `out[i]`. Panics on mismatched lengths.
#[macro_export]
macro_rules! impl_v_div_scalar_half {
    () => {
        #[inline(always)]
        fn v_div_scalar_f16(&self, x: &[half::f16], scalar: half::f16, out: &mut [half::f16]) {
            assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
            for (dst, src) in out.iter_mut().zip(x.iter()) {
                *dst = *src / scalar;
            }
        }

        #[inline(always)]
        fn v_div_scalar_bf16(&self, x: &[half::bf16], scalar: half::bf16, out: &mut [half::bf16]) {
            assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
            for (dst, src) in out.iter_mut().zip(x.iter()) {
                *dst = *src / scalar;
            }
        }
    };
}
1028
/// Generate vectorized tangent function for all numeric types.
///
/// For each listed type `$t`, expands to `v_tan_<t>(&self, x, out)`
/// storing `tan(x[i])` into `out[i]`. Uses the inherent `tan` method,
/// so it only instantiates for float types. Panics on mismatched lengths.
#[macro_export]
macro_rules! impl_v_tan {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                #[inline(always)]
                fn [<v_tan_ $t>](&self, x: &[$t], out: &mut [$t]) {
                    assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
                    for (dst, &src) in out.iter_mut().zip(x.iter()) {
                        *dst = src.tan();
                    }
                }
            }
        )+
    };
}
1050
/// Generate vectorized tangent function for f16 and bf16.
///
/// Each element is widened to f32, `tan` is computed there, then the
/// result is narrowed back to the half-precision type.
/// Panics on mismatched lengths.
#[macro_export]
macro_rules! impl_v_tan_half {
    () => {
        #[inline(always)]
        fn v_tan_f16(&self, x: &[half::f16], out: &mut [half::f16]) {
            assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
            for (dst, src) in out.iter_mut().zip(x.iter()) {
                *dst = half::f16::from_f32(src.to_f32().tan());
            }
        }

        #[inline(always)]
        fn v_tan_bf16(&self, x: &[half::bf16], out: &mut [half::bf16]) {
            assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
            for (dst, src) in out.iter_mut().zip(x.iter()) {
                *dst = half::bf16::from_f32(src.to_f32().tan());
            }
        }
    };
}
1080
/// Generate vectorized reciprocal function for all numeric types.
///
/// For each listed type `$t`, expands to `v_recip_<t>(&self, x, out)`
/// storing `1/x[i]` into `out[i]`. Uses the inherent float `recip`
/// method (equivalent to `1.0 / x`), so — like the original `1.0 / x`
/// form — it only instantiates for float types. Panics on mismatched
/// lengths.
///
/// Fix: the previous body wrapped the assigned value in redundant
/// parentheses (`*o = (1.0 / (*xi));`), which triggers the
/// `unused_parens` lint at every expansion site.
#[macro_export]
macro_rules! impl_v_recip {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                #[inline(always)]
                fn [<v_recip_ $t>](&self, x: &[$t], out: &mut [$t]) {
                    assert_eq!(
                        x.len(),
                        out.len(),
                        "Input and output slices must have same length"
                    );
                    for (o, xi) in out.iter_mut().zip(x.iter()) {
                        *o = (*xi).recip();
                    }
                }
            }
        )+
    };
}
1102
/// Generate vectorized reciprocal function for f16 and bf16.
///
/// Computes `1/x[i]` directly in half precision using the type's `ONE`
/// constant. Panics on mismatched lengths.
#[macro_export]
macro_rules! impl_v_recip_half {
    () => {
        #[inline(always)]
        fn v_recip_f16(&self, x: &[half::f16], out: &mut [half::f16]) {
            assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
            for (dst, &src) in out.iter_mut().zip(x.iter()) {
                *dst = half::f16::ONE / src;
            }
        }

        #[inline(always)]
        fn v_recip_bf16(&self, x: &[half::bf16], out: &mut [half::bf16]) {
            assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
            for (dst, &src) in out.iter_mut().zip(x.iter()) {
                *dst = half::bf16::ONE / src;
            }
        }
    };
}
1132
/// Generate vectorized floor function for all numeric types.
///
/// For each listed type `$t`, expands to `v_floor_<t>(&self, x, out)`
/// storing `floor(x[i])` into `out[i]`. Uses the inherent `floor`
/// method, so it only instantiates for float types.
/// Panics on mismatched lengths.
#[macro_export]
macro_rules! impl_v_floor {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                #[inline(always)]
                fn [<v_floor_ $t>](&self, x: &[$t], out: &mut [$t]) {
                    assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
                    for (dst, &src) in out.iter_mut().zip(x.iter()) {
                        *dst = src.floor();
                    }
                }
            }
        )+
    };
}
1154
/// Generate vectorized floor function for f16 and bf16.
///
/// Each element is widened to f32, floored there, then narrowed back
/// to the half-precision type. Panics on mismatched lengths.
#[macro_export]
macro_rules! impl_v_floor_half {
    () => {
        #[inline(always)]
        fn v_floor_f16(&self, x: &[half::f16], out: &mut [half::f16]) {
            assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
            for (dst, src) in out.iter_mut().zip(x.iter()) {
                *dst = half::f16::from_f32(src.to_f32().floor());
            }
        }

        #[inline(always)]
        fn v_floor_bf16(&self, x: &[half::bf16], out: &mut [half::bf16]) {
            assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
            for (dst, src) in out.iter_mut().zip(x.iter()) {
                *dst = half::bf16::from_f32(src.to_f32().floor());
            }
        }
    };
}
1184
/// Generate vectorized ceiling function for all numeric types.
///
/// For each listed type `$t`, expands to `v_ceil_<t>(&self, x, out)`
/// storing `ceil(x[i])` into `out[i]`. Uses the inherent `ceil` method,
/// so it only instantiates for float types. Panics on mismatched lengths.
#[macro_export]
macro_rules! impl_v_ceil {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                #[inline(always)]
                fn [<v_ceil_ $t>](&self, x: &[$t], out: &mut [$t]) {
                    assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
                    for (dst, &src) in out.iter_mut().zip(x.iter()) {
                        *dst = src.ceil();
                    }
                }
            }
        )+
    };
}
1206
/// Generate vectorized ceiling function for f16 and bf16.
///
/// Each element is widened to f32, the ceiling is taken there, then the
/// result is narrowed back to the half-precision type.
/// Panics on mismatched lengths.
#[macro_export]
macro_rules! impl_v_ceil_half {
    () => {
        #[inline(always)]
        fn v_ceil_f16(&self, x: &[half::f16], out: &mut [half::f16]) {
            assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
            for (dst, src) in out.iter_mut().zip(x.iter()) {
                *dst = half::f16::from_f32(src.to_f32().ceil());
            }
        }

        #[inline(always)]
        fn v_ceil_bf16(&self, x: &[half::bf16], out: &mut [half::bf16]) {
            assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
            for (dst, src) in out.iter_mut().zip(x.iter()) {
                *dst = half::bf16::from_f32(src.to_f32().ceil());
            }
        }
    };
}
1236
/// Generate vectorized round function for all numeric types.
///
/// For each listed type `$t`, expands to `v_round_<t>(&self, x, out)`
/// storing `round(x[i])` (half-away-from-zero, per Rust's `round`)
/// into `out[i]`. Uses the inherent `round` method, so it only
/// instantiates for float types. Panics on mismatched lengths.
#[macro_export]
macro_rules! impl_v_round {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                #[inline(always)]
                fn [<v_round_ $t>](&self, x: &[$t], out: &mut [$t]) {
                    assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
                    for (dst, &src) in out.iter_mut().zip(x.iter()) {
                        *dst = src.round();
                    }
                }
            }
        )+
    };
}
1258
/// Generate vectorized round function for f16 and bf16.
///
/// Each element is widened to f32, rounded there (half-away-from-zero,
/// per Rust's `round`), then narrowed back to the half-precision type.
/// Panics on mismatched lengths.
#[macro_export]
macro_rules! impl_v_round_half {
    () => {
        #[inline(always)]
        fn v_round_f16(&self, x: &[half::f16], out: &mut [half::f16]) {
            assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
            for (dst, src) in out.iter_mut().zip(x.iter()) {
                *dst = half::f16::from_f32(src.to_f32().round());
            }
        }

        #[inline(always)]
        fn v_round_bf16(&self, x: &[half::bf16], out: &mut [half::bf16]) {
            assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
            for (dst, src) in out.iter_mut().zip(x.iter()) {
                *dst = half::bf16::from_f32(src.to_f32().round());
            }
        }
    };
}
1288
/// Generate vectorized absolute value for all numeric types.
///
/// For each listed type `$t`, expands to `v_abs_<t>(&self, x, out)`
/// storing `|x[i]|` into `out[i]` via the type's inherent `abs`, so it
/// instantiates for floats and signed integers. Panics on mismatched
/// lengths. NOTE(review): for signed integers, `abs` of the minimum
/// value panics in debug builds (overflow) — inherited behavior.
#[macro_export]
macro_rules! impl_v_abs {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                #[inline(always)]
                fn [<v_abs_ $t>](&self, x: &[$t], out: &mut [$t]) {
                    assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
                    for (dst, &src) in out.iter_mut().zip(x.iter()) {
                        *dst = src.abs();
                    }
                }
            }
        )+
    };
}
1310
/// Generate vectorized absolute value for f16 and bf16.
///
/// Each element is widened to f32, `abs` is applied there, then the
/// result is narrowed back to the half-precision type.
/// Panics on mismatched lengths.
#[macro_export]
macro_rules! impl_v_abs_half {
    () => {
        #[inline(always)]
        fn v_abs_f16(&self, x: &[half::f16], out: &mut [half::f16]) {
            assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
            for (dst, src) in out.iter_mut().zip(x.iter()) {
                *dst = half::f16::from_f32(src.to_f32().abs());
            }
        }

        #[inline(always)]
        fn v_abs_bf16(&self, x: &[half::bf16], out: &mut [half::bf16]) {
            assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
            for (dst, src) in out.iter_mut().zip(x.iter()) {
                *dst = half::bf16::from_f32(src.to_f32().abs());
            }
        }
    };
}
1340
/// Generate vectorized negation for signed numeric types.
///
/// For each listed type `$t`, expands to `v_neg_<t>(&self, x, out)`
/// storing `-x[i]` into `out[i]`. Requires `$t: Neg`, i.e. signed
/// integers or floats. Panics on mismatched lengths.
#[macro_export]
macro_rules! impl_v_neg {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                #[inline(always)]
                fn [<v_neg_ $t>](&self, x: &[$t], out: &mut [$t]) {
                    assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
                    for (dst, &src) in out.iter_mut().zip(x.iter()) {
                        *dst = -src;
                    }
                }
            }
        )+
    };
}
1362
/// Generate vectorized negation for f16 and bf16.
///
/// Negation is a sign-bit flip, so it is computed directly in half
/// precision without widening. Panics on mismatched lengths.
#[macro_export]
macro_rules! impl_v_neg_half {
    () => {
        #[inline(always)]
        fn v_neg_f16(&self, x: &[half::f16], out: &mut [half::f16]) {
            assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
            for (dst, &src) in out.iter_mut().zip(x.iter()) {
                *dst = -src;
            }
        }

        #[inline(always)]
        fn v_neg_bf16(&self, x: &[half::bf16], out: &mut [half::bf16]) {
            assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
            for (dst, &src) in out.iter_mut().zip(x.iter()) {
                *dst = -src;
            }
        }
    };
}
1392
/// Generate vectorized power function for all numeric types.
///
/// For each listed type `$t`, expands to `v_pow_<t>(&self, a, b, out)`
/// storing `a[i].powf(b[i])` into `out[i]`. `powf` is a float inherent
/// method, so this only instantiates for float types.
/// Panics if the three slices do not all share the same length.
///
/// Fix: removed the redundant parentheses that wrapped the assigned
/// value (`*o = ((*ai).powf(*bi));`), which trigger the `unused_parens`
/// lint at every expansion site.
#[macro_export]
macro_rules! impl_v_pow {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                #[inline(always)]
                fn [<v_pow_ $t>](&self, a: &[$t], b: &[$t], out: &mut [$t]) {
                    assert_eq!(a.len(), b.len(), "Input slices must have same length");
                    assert_eq!(
                        a.len(),
                        out.len(),
                        "Input and output slices must have same length"
                    );
                    for ((o, ai), bi) in out.iter_mut().zip(a.iter()).zip(b.iter()) {
                        *o = (*ai).powf(*bi);
                    }
                }
            }
        )+
    };
}
1415
/// Generate vectorized ReLU operation for floating point types.
///
/// For each listed type `$t`, expands to `relu_<t>(&self, x, out)`
/// storing `max(x[i], 0.0)` into `out[i]`. Panics on mismatched lengths.
#[macro_export]
macro_rules! impl_relu {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                #[inline(always)]
                fn [<relu_ $t>](&self, x: &[$t], out: &mut [$t]) {
                    assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
                    for (dst, &src) in out.iter_mut().zip(x.iter()) {
                        *dst = src.max(0.0);
                    }
                }
            }
        )+
    };
}
1437
/// Generate vectorized ReLU operation for integer types.
///
/// For each listed type `$t`, expands to `relu_<t>(&self, x, out)`
/// storing `max(x[i], 0)` into `out[i]`. Panics on mismatched lengths.
#[macro_export]
macro_rules! impl_relu_int {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                #[inline(always)]
                fn [<relu_ $t>](&self, x: &[$t], out: &mut [$t]) {
                    assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
                    for (dst, &src) in out.iter_mut().zip(x.iter()) {
                        *dst = src.max(0);
                    }
                }
            }
        )+
    };
}
1459
/// Generate vectorized ReLU operation for unsigned integer types.
///
/// Unsigned values are already non-negative, so `max(x, 0) == x` and
/// ReLU degenerates to a copy. Improved to use `copy_from_slice`
/// (a single `memcpy`) instead of an element-by-element loop; the
/// length assertion above it preserves the original panic message for
/// mismatched slices.
#[macro_export]
macro_rules! impl_relu_uint {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                #[inline(always)]
                fn [<relu_ $t>](&self, x: &[$t], out: &mut [$t]) {
                    assert_eq!(
                        x.len(),
                        out.len(),
                        "Input and output slices must have same length"
                    );
                    // Identity for unsigned types: bulk copy.
                    out.copy_from_slice(x);
                }
            }
        )+
    };
}
1481
/// Generate vectorized ReLU operation for half precision types.
///
/// Expands to `relu_f16` and `relu_bf16`, each storing
/// `max(x[i], 0)` into `out[i]` using the type's `ZERO` constant.
/// Panics on mismatched lengths.
#[macro_export]
macro_rules! impl_relu_half {
    () => {
        #[inline(always)]
        fn relu_f16(&self, x: &[half::f16], out: &mut [half::f16]) {
            assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
            for (dst, &src) in out.iter_mut().zip(x.iter()) {
                *dst = src.max(half::f16::ZERO);
            }
        }

        #[inline(always)]
        fn relu_bf16(&self, x: &[half::bf16], out: &mut [half::bf16]) {
            assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
            for (dst, &src) in out.iter_mut().zip(x.iter()) {
                *dst = src.max(half::bf16::ZERO);
            }
        }
    };
}
1511
/// Generate sum operations for integer types.
///
/// For each `$t => $acc` pair, expands to `sum_<t>(&self, x) -> f64`
/// that accumulates in the wider type `$acc` (to limit overflow risk)
/// and returns the total as `f64`. An empty slice sums to `0.0`.
#[macro_export]
macro_rules! impl_sum_int {
    ($($t:ty => $acc:ty),+) => {
        $(
            paste::paste! {
                #[inline(always)]
                fn [<sum_ $t>](&self, x: &[$t]) -> f64 {
                    if x.is_empty() {
                        return 0.0;
                    }
                    // Widen each element before adding so the running
                    // total lives entirely in $acc.
                    let total = x.iter().fold(0 as $acc, |acc, &v| acc + v as $acc);
                    total as f64
                }
            }
        )+
    };
}
1533
/// Generate mean operations for floating point types.
///
/// For each listed type `$t`, expands to `mean_<t>(&self, x) -> $t`
/// that averages the slice via the backend's matching `sum_<t>`.
/// An empty slice yields zero rather than NaN from a 0/0 division.
#[macro_export]
macro_rules! impl_mean_float {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                #[inline(always)]
                fn [<mean_ $t>](&self, x: &[$t]) -> $t {
                    // Guard the 0/0 case explicitly.
                    if x.is_empty() {
                        return 0 as $t;
                    }
                    self.[<sum_ $t>](x) / (x.len() as $t)
                }
            }
        )+
    };
}
1552
/// Generate mean operations for half precision types.
///
/// Expands to `mean_f16` and `mean_bf16`, each returning the `f64`
/// average computed from the backend's `sum_f16`/`sum_bf16`.
/// An empty slice averages to `0.0` rather than NaN.
#[macro_export]
macro_rules! impl_mean_half {
    () => {
        #[inline(always)]
        fn mean_f16(&self, x: &[half::f16]) -> f64 {
            match x.len() {
                0 => 0.0f64,
                n => self.sum_f16(x) / (n as f64),
            }
        }

        #[inline(always)]
        fn mean_bf16(&self, x: &[half::bf16]) -> f64 {
            match x.len() {
                0 => 0.0f64,
                n => self.sum_bf16(x) / (n as f64),
            }
        }
    };
}
1576
/// Generate mean operations for integer types.
///
/// For each listed type `$t`, expands to `mean_<t>(&self, x) -> f64`
/// that averages the slice via the backend's `sum_<t>` (which already
/// returns `f64`). An empty slice averages to `0.0` rather than NaN.
#[macro_export]
macro_rules! impl_mean_int {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                #[inline(always)]
                fn [<mean_ $t>](&self, x: &[$t]) -> f64 {
                    // Guard the 0/0 case explicitly.
                    if x.is_empty() {
                        return 0.0f64;
                    }
                    self.[<sum_ $t>](x) / (x.len() as f64)
                }
            }
        )+
    };
}
1595
/// Generate scalar addition operations for all numeric types.
///
/// For each listed type `$t`, expands to `v_add_scalar_<t>(&self, x, scalar, out)`
/// storing `x[i] + scalar` into `out[i]`. Panics on mismatched lengths.
#[macro_export]
macro_rules! impl_v_add_scalar {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                #[inline(always)]
                fn [<v_add_scalar_ $t>](&self, x: &[$t], scalar: $t, out: &mut [$t]) {
                    assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
                    for (dst, &src) in out.iter_mut().zip(x.iter()) {
                        *dst = src + scalar;
                    }
                }
            }
        )+
    };
}
1617
/// Generate scalar addition operations for half precision types.
///
/// Expands to `v_add_scalar_f16` and `v_add_scalar_bf16`, each writing
/// `x[i] + scalar` into `out[i]`. Panics on mismatched lengths.
#[macro_export]
macro_rules! impl_v_add_scalar_half {
    () => {
        #[inline(always)]
        fn v_add_scalar_f16(&self, x: &[half::f16], scalar: half::f16, out: &mut [half::f16]) {
            assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
            for (dst, &src) in out.iter_mut().zip(x.iter()) {
                *dst = src + scalar;
            }
        }

        #[inline(always)]
        fn v_add_scalar_bf16(&self, x: &[half::bf16], scalar: half::bf16, out: &mut [half::bf16]) {
            assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
            for (dst, &src) in out.iter_mut().zip(x.iter()) {
                *dst = src + scalar;
            }
        }
    };
}
1647
/// Generate scalar subtraction operations for all numeric types.
///
/// For each listed type `$t`, expands to `v_sub_scalar_<t>(&self, x, scalar, out)`
/// storing `x[i] - scalar` into `out[i]`. Panics on mismatched lengths.
#[macro_export]
macro_rules! impl_v_sub_scalar {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                #[inline(always)]
                fn [<v_sub_scalar_ $t>](&self, x: &[$t], scalar: $t, out: &mut [$t]) {
                    assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
                    for (dst, &src) in out.iter_mut().zip(x.iter()) {
                        *dst = src - scalar;
                    }
                }
            }
        )+
    };
}
1669
/// Generate scalar subtraction operations for half precision types.
///
/// Expands to `v_sub_scalar_f16` and `v_sub_scalar_bf16`, each writing
/// `x[i] - scalar` into `out[i]`. Panics on mismatched lengths.
#[macro_export]
macro_rules! impl_v_sub_scalar_half {
    () => {
        #[inline(always)]
        fn v_sub_scalar_f16(&self, x: &[half::f16], scalar: half::f16, out: &mut [half::f16]) {
            assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
            for (dst, &src) in out.iter_mut().zip(x.iter()) {
                *dst = src - scalar;
            }
        }

        #[inline(always)]
        fn v_sub_scalar_bf16(&self, x: &[half::bf16], scalar: half::bf16, out: &mut [half::bf16]) {
            assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
            for (dst, &src) in out.iter_mut().zip(x.iter()) {
                *dst = src - scalar;
            }
        }
    };
}
1699
/// Generate scalar multiplication operations for all numeric types.
///
/// For each listed type `$t`, expands to `v_mul_scalar_<t>(&self, x, scalar, out)`
/// storing `x[i] * scalar` into `out[i]`. Panics on mismatched lengths.
#[macro_export]
macro_rules! impl_v_mul_scalar {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                #[inline(always)]
                fn [<v_mul_scalar_ $t>](&self, x: &[$t], scalar: $t, out: &mut [$t]) {
                    assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
                    for (dst, &src) in out.iter_mut().zip(x.iter()) {
                        *dst = src * scalar;
                    }
                }
            }
        )+
    };
}
1721
/// Generate scalar multiplication operations for half precision types.
///
/// Expands to `v_mul_scalar_f16` and `v_mul_scalar_bf16`, each writing
/// `x[i] * scalar` into `out[i]`. Panics on mismatched lengths.
#[macro_export]
macro_rules! impl_v_mul_scalar_half {
    () => {
        #[inline(always)]
        fn v_mul_scalar_f16(&self, x: &[half::f16], scalar: half::f16, out: &mut [half::f16]) {
            assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
            for (dst, &src) in out.iter_mut().zip(x.iter()) {
                *dst = src * scalar;
            }
        }

        #[inline(always)]
        fn v_mul_scalar_bf16(&self, x: &[half::bf16], scalar: half::bf16, out: &mut [half::bf16]) {
            assert_eq!(x.len(), out.len(), "Input and output slices must have same length");
            for (dst, &src) in out.iter_mut().zip(x.iter()) {
                *dst = src * scalar;
            }
        }
    };
}
1751
/// Generate vectorized maximum value for all numeric types
///
/// Each generated `max_v_<type>` returns the largest element of `x`.
/// Panics on an empty slice.
#[macro_export]
macro_rules! impl_max_v {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                #[inline(always)]
                fn [<max_v_ $t>](&self, x: &[$t]) -> $t {
                    if x.is_empty() {
                        panic!("Cannot find maximum of empty vector");
                    }
                    // `partial_cmp` lets the same code serve floats and integers;
                    // incomparable pairs (float NaN) compare as Equal, so with NaN
                    // present the result depends on element order rather than being
                    // a well-defined maximum. The trailing `unwrap_or(&x[0])` is
                    // unreachable after the emptiness check above.
                    *x.iter().max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)).unwrap_or(&x[0])
                }
            }
        )+
    };
}
1769
/// Generate vectorized minimum value for all numeric types
///
/// Each generated `min_v_<type>` returns the smallest element of `x`.
/// Panics on an empty slice.
#[macro_export]
macro_rules! impl_min_v {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                #[inline(always)]
                fn [<min_v_ $t>](&self, x: &[$t]) -> $t {
                    if x.is_empty() {
                        panic!("Cannot find minimum of empty vector");
                    }
                    // `partial_cmp` lets the same code serve floats and integers;
                    // incomparable pairs (float NaN) compare as Equal, so with NaN
                    // present the result depends on element order rather than being
                    // a well-defined minimum. The trailing `unwrap_or(&x[0])` is
                    // unreachable after the emptiness check above.
                    *x.iter().min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)).unwrap_or(&x[0])
                }
            }
        )+
    };
}
1787
/// Generate vectorized maximum value with index for all numeric types
///
/// Each generated `max_vi_<type>` returns `(max value, index of first
/// occurrence)`. Panics on an empty slice.
#[macro_export]
macro_rules! impl_max_vi {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                #[inline(always)]
                fn [<max_vi_ $t>](&self, x: &[$t]) -> ($t, u64) {
                    if x.is_empty() {
                        panic!("Cannot find maximum of empty vector");
                    }
                    // `max_by` keeps the LAST of elements that compare Equal, so
                    // ties on the value are broken by comparing indices in reverse:
                    // the smaller index compares Greater and therefore wins, giving
                    // argmax-style "first occurrence" semantics. NaN comparisons
                    // (partial_cmp == None) are treated as Equal, so results with
                    // NaN present depend on element order. The `unwrap_or((0, &x[0]))`
                    // is unreachable after the emptiness check.
                    let (idx, val) = x.iter()
                        .enumerate()
                        .max_by(|a, b| {
                            match a.1.partial_cmp(b.1) {
                                Some(std::cmp::Ordering::Equal) => b.0.cmp(&a.0), // Reverse index order for equal values
                                Some(ordering) => ordering,
                                None => std::cmp::Ordering::Equal,
                            }
                        })
                        .unwrap_or((0, &x[0]));
                    (*val, idx as u64)
                }
            }
        )+
    };
}
1815
/// Generate vectorized minimum value with index for all numeric types
#[macro_export]
macro_rules! impl_min_vi {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                /// Return `(min value, index of its first occurrence)`.
                ///
                /// Ties resolve toward the smaller index; incomparable pairs
                /// (e.g. NaN) are treated as equal. Panics on an empty slice.
                #[inline(always)]
                fn [<min_vi_ $t>](&self, x: &[$t]) -> ($t, u64) {
                    use std::cmp::Ordering;
                    assert!(!x.is_empty(), "Cannot find minimum of empty vector");
                    let (idx, val) = x
                        .iter()
                        .enumerate()
                        .min_by(|lhs, rhs| match lhs.1.partial_cmp(rhs.1) {
                            // Ascending index order keeps the earlier index on ties.
                            Some(Ordering::Equal) => lhs.0.cmp(&rhs.0),
                            Some(ord) => ord,
                            None => Ordering::Equal,
                        })
                        .unwrap_or((0, &x[0]));
                    (*val, idx as u64)
                }
            }
        )+
    };
}
1843
/// Generate vectorized maximum value for half precision types
#[macro_export]
macro_rules! impl_max_v_half {
    () => {
        /// Return the largest `f16` in `x`; NaN pairs compare as equal.
        /// Panics on an empty slice.
        #[inline(always)]
        fn max_v_f16(&self, x: &[half::f16]) -> half::f16 {
            assert!(!x.is_empty(), "Cannot find maximum of empty vector");
            x.iter()
                .copied()
                .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                .unwrap_or(x[0])
        }

        /// Return the largest `bf16` in `x`; NaN pairs compare as equal.
        /// Panics on an empty slice.
        #[inline(always)]
        fn max_v_bf16(&self, x: &[half::bf16]) -> half::bf16 {
            assert!(!x.is_empty(), "Cannot find maximum of empty vector");
            x.iter()
                .copied()
                .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                .unwrap_or(x[0])
        }
    };
}
1869
/// Generate vectorized minimum value for half precision types
#[macro_export]
macro_rules! impl_min_v_half {
    () => {
        /// Return the smallest `f16` in `x`; NaN pairs compare as equal.
        /// Panics on an empty slice.
        #[inline(always)]
        fn min_v_f16(&self, x: &[half::f16]) -> half::f16 {
            assert!(!x.is_empty(), "Cannot find minimum of empty vector");
            x.iter()
                .copied()
                .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                .unwrap_or(x[0])
        }

        /// Return the smallest `bf16` in `x`; NaN pairs compare as equal.
        /// Panics on an empty slice.
        #[inline(always)]
        fn min_v_bf16(&self, x: &[half::bf16]) -> half::bf16 {
            assert!(!x.is_empty(), "Cannot find minimum of empty vector");
            x.iter()
                .copied()
                .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                .unwrap_or(x[0])
        }
    };
}
1895
/// Generate vectorized maximum value with index for half precision types
#[macro_export]
macro_rules! impl_max_vi_half {
    () => {
        /// Return `(max value, index of its first occurrence)` for an `f16` slice.
        ///
        /// Ties resolve toward the smaller index; incomparable (NaN) pairs are
        /// treated as equal. Panics on an empty slice.
        #[inline(always)]
        fn max_vi_f16(&self, x: &[half::f16]) -> (half::f16, u64) {
            use std::cmp::Ordering;
            assert!(!x.is_empty(), "Cannot find maximum of empty vector");
            let (idx, val) = x
                .iter()
                .enumerate()
                .max_by(|lhs, rhs| match lhs.1.partial_cmp(rhs.1) {
                    // Reversed index order so `max_by` (which keeps the last of
                    // equal elements) lands on the earlier index.
                    Some(Ordering::Equal) => rhs.0.cmp(&lhs.0),
                    Some(ord) => ord,
                    None => Ordering::Equal,
                })
                .unwrap_or((0, &x[0]));
            (*val, idx as u64)
        }

        /// Return `(max value, index of its first occurrence)` for a `bf16`
        /// slice; same contract as `max_vi_f16`.
        #[inline(always)]
        fn max_vi_bf16(&self, x: &[half::bf16]) -> (half::bf16, u64) {
            use std::cmp::Ordering;
            assert!(!x.is_empty(), "Cannot find maximum of empty vector");
            let (idx, val) = x
                .iter()
                .enumerate()
                .max_by(|lhs, rhs| match lhs.1.partial_cmp(rhs.1) {
                    Some(Ordering::Equal) => rhs.0.cmp(&lhs.0),
                    Some(ord) => ord,
                    None => Ordering::Equal,
                })
                .unwrap_or((0, &x[0]));
            (*val, idx as u64)
        }
    };
}
1939
/// Generate vectorized minimum value with index for half precision types
#[macro_export]
macro_rules! impl_min_vi_half {
    () => {
        /// Return `(min value, index of its first occurrence)` for an `f16` slice.
        ///
        /// Ties resolve toward the smaller index; incomparable (NaN) pairs are
        /// treated as equal. Panics on an empty slice.
        #[inline(always)]
        fn min_vi_f16(&self, x: &[half::f16]) -> (half::f16, u64) {
            use std::cmp::Ordering;
            assert!(!x.is_empty(), "Cannot find minimum of empty vector");
            let (idx, val) = x
                .iter()
                .enumerate()
                .min_by(|lhs, rhs| match lhs.1.partial_cmp(rhs.1) {
                    // Ascending index order keeps the earlier index on ties.
                    Some(Ordering::Equal) => lhs.0.cmp(&rhs.0),
                    Some(ord) => ord,
                    None => Ordering::Equal,
                })
                .unwrap_or((0, &x[0]));
            (*val, idx as u64)
        }

        /// Return `(min value, index of its first occurrence)` for a `bf16`
        /// slice; same contract as `min_vi_f16`.
        #[inline(always)]
        fn min_vi_bf16(&self, x: &[half::bf16]) -> (half::bf16, u64) {
            use std::cmp::Ordering;
            assert!(!x.is_empty(), "Cannot find minimum of empty vector");
            let (idx, val) = x
                .iter()
                .enumerate()
                .min_by(|lhs, rhs| match lhs.1.partial_cmp(rhs.1) {
                    Some(Ordering::Equal) => lhs.0.cmp(&rhs.0),
                    Some(ord) => ord,
                    None => Ordering::Equal,
                })
                .unwrap_or((0, &x[0]));
            (*val, idx as u64)
        }
    };
}
1983
/// Generate vectorized min and max values for all numeric types
#[macro_export]
macro_rules! impl_min_max_v {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                /// Return `(min, max)` over `x`.
                ///
                /// Incomparable pairs (e.g. NaN) are treated as equal.
                /// Panics on an empty slice.
                #[inline(always)]
                fn [<min_max_v_ $t>](&self, x: &[$t]) -> ($t, $t) {
                    assert!(!x.is_empty(), "Cannot find min/max of empty vector");
                    // Non-capturing closure is `Copy`, so it can be handed to
                    // both the min and the max pass.
                    let order = |a: &$t, b: &$t| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal);
                    let lo = x.iter().copied().min_by(order).unwrap_or(x[0]);
                    let hi = x.iter().copied().max_by(order).unwrap_or(x[0]);
                    (lo, hi)
                }
            }
        )+
    };
}
2003
/// Generate vectorized min and max values with indices for all numeric types
#[macro_export]
macro_rules! impl_min_max_vi {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                /// Single-pass `((min, min_idx), (max, max_idx))` over `x`.
                ///
                /// Strict comparisons keep the first occurrence on ties.
                /// Panics on an empty slice.
                #[inline(always)]
                fn [<min_max_vi_ $t>](&self, x: &[$t]) -> (($t, u64), ($t, u64)) {
                    assert!(!x.is_empty(), "Cannot find min/max of empty vector");
                    let mut lo = (x[0], 0usize);
                    let mut hi = (x[0], 0usize);
                    // Element 0 seeds both trackers, so the scan starts at 1.
                    for (i, &v) in x.iter().enumerate().skip(1) {
                        if v < lo.0 {
                            lo = (v, i);
                        }
                        if v > hi.0 {
                            hi = (v, i);
                        }
                    }
                    ((lo.0, lo.1 as u64), (hi.0, hi.1 as u64))
                }
            }
        )+
    };
}
2037
/// Generate vectorized min and max indices for all numeric types
#[macro_export]
macro_rules! impl_min_max_i {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                /// Return `(index of min, index of max)`.
                ///
                /// Ties resolve to the first occurrence; incomparable pairs
                /// (e.g. NaN) are treated as equal. Panics on an empty slice.
                #[inline(always)]
                fn [<min_max_i_ $t>](&self, x: &[$t]) -> (u64, u64) {
                    use std::cmp::Ordering;
                    assert!(!x.is_empty(), "Cannot find min/max indices of empty vector");
                    let lo = x
                        .iter()
                        .enumerate()
                        .min_by(|lhs, rhs| match lhs.1.partial_cmp(rhs.1) {
                            // Ascending index order keeps the earlier index on ties.
                            Some(Ordering::Equal) => lhs.0.cmp(&rhs.0),
                            Some(ord) => ord,
                            None => Ordering::Equal,
                        })
                        .unwrap_or((0, &x[0]))
                        .0;
                    let hi = x
                        .iter()
                        .enumerate()
                        .max_by(|lhs, rhs| match lhs.1.partial_cmp(rhs.1) {
                            // Reversed so `max_by` keeps the earlier index on ties.
                            Some(Ordering::Equal) => rhs.0.cmp(&lhs.0),
                            Some(ord) => ord,
                            None => Ordering::Equal,
                        })
                        .unwrap_or((0, &x[0]))
                        .0;
                    (lo as u64, hi as u64)
                }
            }
        )+
    };
}
2077
/// Generate vectorized min and max values for half precision types
#[macro_export]
macro_rules! impl_min_max_v_half {
    () => {
        /// Return `(min, max)` over an `f16` slice; NaN pairs compare as equal.
        /// Panics on an empty slice.
        #[inline(always)]
        fn min_max_v_f16(&self, x: &[half::f16]) -> (half::f16, half::f16) {
            assert!(!x.is_empty(), "Cannot find min/max of empty vector");
            let order = |a: &half::f16, b: &half::f16| {
                a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
            };
            let lo = x.iter().copied().min_by(order).unwrap_or(x[0]);
            let hi = x.iter().copied().max_by(order).unwrap_or(x[0]);
            (lo, hi)
        }

        /// Return `(min, max)` over a `bf16` slice; NaN pairs compare as equal.
        /// Panics on an empty slice.
        #[inline(always)]
        fn min_max_v_bf16(&self, x: &[half::bf16]) -> (half::bf16, half::bf16) {
            assert!(!x.is_empty(), "Cannot find min/max of empty vector");
            let order = |a: &half::bf16, b: &half::bf16| {
                a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
            };
            let lo = x.iter().copied().min_by(order).unwrap_or(x[0]);
            let hi = x.iter().copied().max_by(order).unwrap_or(x[0]);
            (lo, hi)
        }
    };
}
2115
/// Generate vectorized min and max values with indices for half precision types
#[macro_export]
macro_rules! impl_min_max_vi_half {
    () => {
        /// Single-pass `((min, min_idx), (max, max_idx))` over an `f16` slice.
        ///
        /// Strict comparisons keep the first occurrence on ties.
        /// Panics on an empty slice.
        #[inline(always)]
        fn min_max_vi_f16(&self, x: &[half::f16]) -> ((half::f16, u64), (half::f16, u64)) {
            assert!(!x.is_empty(), "Cannot find min/max of empty vector");
            let mut lo = (x[0], 0usize);
            let mut hi = (x[0], 0usize);
            // Element 0 seeds both trackers, so the scan starts at 1.
            for (i, &v) in x.iter().enumerate().skip(1) {
                if v < lo.0 {
                    lo = (v, i);
                }
                if v > hi.0 {
                    hi = (v, i);
                }
            }
            ((lo.0, lo.1 as u64), (hi.0, hi.1 as u64))
        }

        /// Single-pass `((min, min_idx), (max, max_idx))` over a `bf16` slice;
        /// same contract as `min_max_vi_f16`.
        #[inline(always)]
        fn min_max_vi_bf16(&self, x: &[half::bf16]) -> ((half::bf16, u64), (half::bf16, u64)) {
            assert!(!x.is_empty(), "Cannot find min/max of empty vector");
            let mut lo = (x[0], 0usize);
            let mut hi = (x[0], 0usize);
            for (i, &v) in x.iter().enumerate().skip(1) {
                if v < lo.0 {
                    lo = (v, i);
                }
                if v > hi.0 {
                    hi = (v, i);
                }
            }
            ((lo.0, lo.1 as u64), (hi.0, hi.1 as u64))
        }
    };
}
2169
/// Generate vectorized min and max indices for half precision types
#[macro_export]
macro_rules! impl_min_max_i_half {
    () => {
        /// Return `(index of min, index of max)` for an `f16` slice.
        ///
        /// Ties resolve to the FIRST occurrence, matching the generic
        /// `impl_min_max_i` implementations. Without the explicit index
        /// tiebreak, `Iterator::max_by` returns the *last* of equal maxima,
        /// which made `max_idx` here disagree with the generic types.
        /// Incomparable (NaN) pairs are treated as equal.
        ///
        /// # Panics
        /// Panics on an empty slice.
        #[inline(always)]
        fn min_max_i_f16(&self, x: &[half::f16]) -> (u64, u64) {
            if x.is_empty() {
                panic!("Cannot find min/max indices of empty vector");
            }
            let min_idx = x
                .iter()
                .enumerate()
                .min_by(|a, b| {
                    match a.1.partial_cmp(b.1) {
                        Some(std::cmp::Ordering::Equal) => a.0.cmp(&b.0), // Keep index order for equal values
                        Some(ordering) => ordering,
                        None => std::cmp::Ordering::Equal,
                    }
                })
                .unwrap_or((0, &x[0]))
                .0 as u64;
            let max_idx = x
                .iter()
                .enumerate()
                .max_by(|a, b| {
                    match a.1.partial_cmp(b.1) {
                        Some(std::cmp::Ordering::Equal) => b.0.cmp(&a.0), // Reverse index order for equal values
                        Some(ordering) => ordering,
                        None => std::cmp::Ordering::Equal,
                    }
                })
                .unwrap_or((0, &x[0]))
                .0 as u64;
            (min_idx, max_idx)
        }

        /// Return `(index of min, index of max)` for a `bf16` slice; same
        /// contract and tie-breaking as `min_max_i_f16`.
        ///
        /// # Panics
        /// Panics on an empty slice.
        #[inline(always)]
        fn min_max_i_bf16(&self, x: &[half::bf16]) -> (u64, u64) {
            if x.is_empty() {
                panic!("Cannot find min/max indices of empty vector");
            }
            let min_idx = x
                .iter()
                .enumerate()
                .min_by(|a, b| {
                    match a.1.partial_cmp(b.1) {
                        Some(std::cmp::Ordering::Equal) => a.0.cmp(&b.0), // Keep index order for equal values
                        Some(ordering) => ordering,
                        None => std::cmp::Ordering::Equal,
                    }
                })
                .unwrap_or((0, &x[0]))
                .0 as u64;
            let max_idx = x
                .iter()
                .enumerate()
                .max_by(|a, b| {
                    match a.1.partial_cmp(b.1) {
                        Some(std::cmp::Ordering::Equal) => b.0.cmp(&a.0), // Reverse index order for equal values
                        Some(ordering) => ordering,
                        None => std::cmp::Ordering::Equal,
                    }
                })
                .unwrap_or((0, &x[0]))
                .0 as u64;
            (min_idx, max_idx)
        }
    };
}
2215
/// Generate vectorized minimum index for all numeric types
#[macro_export]
macro_rules! impl_min_i {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                /// Return the index of the first occurrence of the minimum.
                ///
                /// Incomparable pairs (e.g. NaN) are treated as equal.
                /// Panics on an empty slice.
                #[inline(always)]
                fn [<min_i_ $t>](&self, x: &[$t]) -> u64 {
                    use std::cmp::Ordering;
                    assert!(!x.is_empty(), "Cannot find minimum index of empty vector");
                    let best = x
                        .iter()
                        .enumerate()
                        .min_by(|lhs, rhs| match lhs.1.partial_cmp(rhs.1) {
                            // Ascending index order keeps the earlier index on ties.
                            Some(Ordering::Equal) => lhs.0.cmp(&rhs.0),
                            Some(ord) => ord,
                            None => Ordering::Equal,
                        })
                        .unwrap_or((0, &x[0]));
                    best.0 as u64
                }
            }
        )+
    };
}
2243
/// Generate vectorized maximum index for all numeric types
#[macro_export]
macro_rules! impl_max_i {
    ($($t:ty),+) => {
        $(
            paste::paste! {
                /// Return the index of the first occurrence of the maximum.
                ///
                /// Incomparable pairs (e.g. NaN) are treated as equal.
                /// Panics on an empty slice.
                #[inline(always)]
                fn [<max_i_ $t>](&self, x: &[$t]) -> u64 {
                    use std::cmp::Ordering;
                    assert!(!x.is_empty(), "Cannot find maximum index of empty vector");
                    let best = x
                        .iter()
                        .enumerate()
                        .max_by(|lhs, rhs| match lhs.1.partial_cmp(rhs.1) {
                            // Reversed index order so `max_by` (which keeps the
                            // last of equal elements) lands on the earlier index.
                            Some(Ordering::Equal) => rhs.0.cmp(&lhs.0),
                            Some(ord) => ord,
                            None => Ordering::Equal,
                        })
                        .unwrap_or((0, &x[0]));
                    best.0 as u64
                }
            }
        )+
    };
}
2271
/// Generate vectorized minimum index for half precision types
#[macro_export]
macro_rules! impl_min_i_half {
    () => {
        /// Return the index of the first occurrence of the minimum in an
        /// `f16` slice.
        ///
        /// The index tiebreak is stated explicitly to mirror the generic
        /// `impl_min_i` (previously it relied on `min_by` happening to keep
        /// the first of equal elements). Incomparable (NaN) pairs are treated
        /// as equal.
        ///
        /// # Panics
        /// Panics on an empty slice.
        #[inline(always)]
        fn min_i_f16(&self, x: &[half::f16]) -> u64 {
            if x.is_empty() {
                panic!("Cannot find minimum index of empty vector");
            }
            x.iter()
                .enumerate()
                .min_by(|a, b| {
                    match a.1.partial_cmp(b.1) {
                        Some(std::cmp::Ordering::Equal) => a.0.cmp(&b.0), // Keep index order for equal values
                        Some(ordering) => ordering,
                        None => std::cmp::Ordering::Equal,
                    }
                })
                .unwrap_or((0, &x[0]))
                .0 as u64
        }

        /// Return the index of the first occurrence of the minimum in a
        /// `bf16` slice; same contract as `min_i_f16`.
        ///
        /// # Panics
        /// Panics on an empty slice.
        #[inline(always)]
        fn min_i_bf16(&self, x: &[half::bf16]) -> u64 {
            if x.is_empty() {
                panic!("Cannot find minimum index of empty vector");
            }
            x.iter()
                .enumerate()
                .min_by(|a, b| {
                    match a.1.partial_cmp(b.1) {
                        Some(std::cmp::Ordering::Equal) => a.0.cmp(&b.0), // Keep index order for equal values
                        Some(ordering) => ordering,
                        None => std::cmp::Ordering::Equal,
                    }
                })
                .unwrap_or((0, &x[0]))
                .0 as u64
        }
    };
}
2301
/// Generate vectorized maximum index for half precision types
#[macro_export]
macro_rules! impl_max_i_half {
    () => {
        /// Return the index of the FIRST occurrence of the maximum in an
        /// `f16` slice.
        ///
        /// Bug fix: without an index tiebreak, `Iterator::max_by` returns the
        /// *last* of equal maxima, so ties yielded a different index than the
        /// generic `impl_max_i` and `max_vi_f16`. The reversed index order in
        /// the comparator restores first-occurrence semantics. Incomparable
        /// (NaN) pairs are treated as equal.
        ///
        /// # Panics
        /// Panics on an empty slice.
        #[inline(always)]
        fn max_i_f16(&self, x: &[half::f16]) -> u64 {
            if x.is_empty() {
                panic!("Cannot find maximum index of empty vector");
            }
            x.iter()
                .enumerate()
                .max_by(|a, b| {
                    match a.1.partial_cmp(b.1) {
                        Some(std::cmp::Ordering::Equal) => b.0.cmp(&a.0), // Reverse index order for equal values
                        Some(ordering) => ordering,
                        None => std::cmp::Ordering::Equal,
                    }
                })
                .unwrap_or((0, &x[0]))
                .0 as u64
        }

        /// Return the index of the FIRST occurrence of the maximum in a
        /// `bf16` slice; same contract as `max_i_f16`.
        ///
        /// # Panics
        /// Panics on an empty slice.
        #[inline(always)]
        fn max_i_bf16(&self, x: &[half::bf16]) -> u64 {
            if x.is_empty() {
                panic!("Cannot find maximum index of empty vector");
            }
            x.iter()
                .enumerate()
                .max_by(|a, b| {
                    match a.1.partial_cmp(b.1) {
                        Some(std::cmp::Ordering::Equal) => b.0.cmp(&a.0), // Reverse index order for equal values
                        Some(ordering) => ordering,
                        None => std::cmp::Ordering::Equal,
                    }
                })
                .unwrap_or((0, &x[0]))
                .0 as u64
        }
    };
}