1include!(concat!(env!("OUT_DIR"), "/simd_lanes.rs"));
24
25use core::simd::{LaneCount, Mask, Simd, SimdElement, SupportedLaneCount};
26use std::ops::{Add, Div, Mul, Rem, Sub};
27use std::simd::StdFloat;
28use std::simd::cmp::SimdPartialEq;
29
30use minarrow::Bitmask;
31use num_traits::{One, PrimInt, ToPrimitive, WrappingAdd, WrappingMul, WrappingSub, Zero};
32
33use crate::kernels::bitmask::simd::all_true_mask_simd;
34use crate::operators::ArithmeticOperator;
35use crate::utils::simd_mask;
36
37#[inline(always)]
41pub fn int_dense_body_simd<T, const LANES: usize>(
42 op: ArithmeticOperator,
43 lhs: &[T],
44 rhs: &[T],
45 out: &mut [T],
46) where
47 T: Copy + One + PrimInt + ToPrimitive + Zero + SimdElement + WrappingMul,
48 LaneCount<LANES>: SupportedLaneCount,
49 Simd<T, LANES>: Add<Output = Simd<T, LANES>>
50 + Sub<Output = Simd<T, LANES>>
51 + Mul<Output = Simd<T, LANES>>
52 + Div<Output = Simd<T, LANES>>
53 + Rem<Output = Simd<T, LANES>>,
54{
55 let n = lhs.len();
56 let mut vectorisable = n / LANES * LANES;
57 let mut i = 0;
58 while i < vectorisable {
59 let a = Simd::<T, LANES>::from_slice(&lhs[i..i + LANES]);
60 let b = Simd::<T, LANES>::from_slice(&rhs[i..i + LANES]);
61 let r = match op {
62 ArithmeticOperator::Add => a + b,
63 ArithmeticOperator::Subtract => a - b,
64 ArithmeticOperator::Multiply => a * b,
65 ArithmeticOperator::Divide => a / b, ArithmeticOperator::Remainder => a % b, ArithmeticOperator::Power => {
68 vectorisable = 0;
69 break;
70 }
71 };
72 r.copy_to_slice(&mut out[i..i + LANES]);
73 i += LANES;
74 }
75
76 for idx in vectorisable..n {
78 out[idx] = match op {
79 ArithmeticOperator::Add => lhs[idx] + rhs[idx],
80 ArithmeticOperator::Subtract => lhs[idx] - rhs[idx],
81 ArithmeticOperator::Multiply => lhs[idx] * rhs[idx],
82 ArithmeticOperator::Divide => lhs[idx] / rhs[idx], ArithmeticOperator::Remainder => lhs[idx] % rhs[idx], ArithmeticOperator::Power => {
85 let mut acc = T::one();
86 let exp = rhs[idx].to_u32().unwrap_or(0);
87 for _ in 0..exp {
88 acc = acc.wrapping_mul(&lhs[idx]);
89 }
90 acc
91 }
92 };
93 }
94}
95
96#[inline(always)]
99pub fn int_masked_body_simd<T, const LANES: usize>(
100 op: ArithmeticOperator,
101 lhs: &[T],
102 rhs: &[T],
103 mask: &Bitmask,
104 out: &mut [T],
105 out_mask: &mut Bitmask,
106) where
107 T: Copy
108 + PrimInt
109 + ToPrimitive
110 + Zero
111 + One
112 + SimdElement
113 + PartialEq
114 + WrappingAdd
115 + WrappingMul
116 + WrappingSub,
117 LaneCount<LANES>: SupportedLaneCount,
118 Simd<T, LANES>: Add<Output = Simd<T, LANES>>
119 + SimdPartialEq<Mask = Mask<<T as SimdElement>::Mask, LANES>>
120 + Sub<Output = Simd<T, LANES>>
121 + Mul<Output = Simd<T, LANES>>
122 + Div<Output = Simd<T, LANES>>
123 + Rem<Output = Simd<T, LANES>>,
124{
125 let n = lhs.len();
126 let dense = all_true_mask_simd(mask);
127
128 if dense {
134 let vectorisable = n / LANES * LANES;
136 let mut i = 0;
137 while i < vectorisable {
138 let a = Simd::<T, LANES>::from_slice(&lhs[i..i + LANES]);
139 let b = Simd::<T, LANES>::from_slice(&rhs[i..i + LANES]);
140
141 let (r, valid): (Simd<T, LANES>, Mask<<T as SimdElement>::Mask, LANES>) = match op {
142 ArithmeticOperator::Add => (a + b, Mask::splat(true)),
143 ArithmeticOperator::Subtract => (a - b, Mask::splat(true)),
144 ArithmeticOperator::Multiply => (a * b, Mask::splat(true)),
145 ArithmeticOperator::Power => {
146 let mut tmp = [T::zero(); LANES];
147 for l in 0..LANES {
148 tmp[l] = a[l].pow(b[l].to_u32().unwrap_or(0));
149 }
150 (Simd::<T, LANES>::from_array(tmp), Mask::splat(true))
151 }
152 ArithmeticOperator::Divide | ArithmeticOperator::Remainder => {
153 let div_zero = b.simd_eq(Simd::splat(T::zero()));
154 let valid = !div_zero;
155 let safe_b = div_zero.select(Simd::splat(T::one()), b);
156 let r = match op {
157 ArithmeticOperator::Divide => a / safe_b,
158 ArithmeticOperator::Remainder => a % safe_b,
159 _ => unreachable!(),
160 };
161 let r = div_zero.select(Simd::splat(T::zero()), r);
162 (r, valid)
163 }
164 };
165 r.copy_to_slice(&mut out[i..i + LANES]);
166 let valid_bits = valid.to_bitmask();
168 for l in 0..LANES {
169 let idx = i + l;
170 if idx < n {
171 unsafe {
172 out_mask.set_unchecked(idx, ((valid_bits >> l) & 1) == 1);
173 }
174 }
175 }
176 i += LANES;
177 }
178 for idx in vectorisable..n {
180 match op {
181 ArithmeticOperator::Add => {
182 out[idx] = lhs[idx].wrapping_add(&rhs[idx]);
183 unsafe {
184 out_mask.set_unchecked(idx, true);
185 }
186 }
187 ArithmeticOperator::Subtract => {
188 out[idx] = lhs[idx].wrapping_sub(&rhs[idx]);
189 unsafe {
190 out_mask.set_unchecked(idx, true);
191 }
192 }
193 ArithmeticOperator::Multiply => {
194 out[idx] = lhs[idx].wrapping_mul(&rhs[idx]);
195 unsafe {
196 out_mask.set_unchecked(idx, true);
197 }
198 }
199 ArithmeticOperator::Power => {
200 out[idx] = lhs[idx].pow(rhs[idx].to_u32().unwrap_or(0));
201 unsafe {
202 out_mask.set_unchecked(idx, true);
203 }
204 }
205 ArithmeticOperator::Divide | ArithmeticOperator::Remainder => {
206 if rhs[idx] == T::zero() {
207 out[idx] = T::zero();
208 unsafe {
209 out_mask.set_unchecked(idx, false);
210 }
211 } else {
212 out[idx] = match op {
213 ArithmeticOperator::Divide => lhs[idx] / rhs[idx],
214 ArithmeticOperator::Remainder => lhs[idx] % rhs[idx],
215 _ => unreachable!(),
216 };
217 unsafe {
218 out_mask.set_unchecked(idx, true);
219 }
220 }
221 }
222 }
223 }
224 return;
225 }
226
227 let mut i = 0;
228 while i + LANES <= n {
229 let a = Simd::<T, LANES>::from_slice(&lhs[i..i + LANES]);
230 let b = Simd::<T, LANES>::from_slice(&rhs[i..i + LANES]);
231 let m_src: Mask<_, LANES> = simd_mask::<_, LANES>(mask, i, n); let div_zero: Mask<_, LANES> = b.simd_eq(Simd::splat(T::zero()));
235
236 let res = match op {
238 ArithmeticOperator::Add => a + b,
239 ArithmeticOperator::Subtract => a - b,
240 ArithmeticOperator::Multiply => a * b,
241 ArithmeticOperator::Divide => {
242 let safe_b = div_zero.select(Simd::splat(T::one()), b); let q = a / safe_b;
244 div_zero.select(Simd::splat(T::zero()), q) }
246 ArithmeticOperator::Remainder => {
247 let safe_b = div_zero.select(Simd::splat(T::one()), b);
248 let r = a % safe_b;
249 div_zero.select(Simd::splat(T::zero()), r)
250 }
251 ArithmeticOperator::Power => {
252 let mut tmp = [T::zero(); LANES];
254 for l in 0..LANES {
255 tmp[l] = a[l].pow(b[l].to_u32().unwrap_or(0));
256 }
257 Simd::<T, LANES>::from_array(tmp)
258 }
259 };
260
261 let selected = m_src.select(res, Simd::splat(T::zero()));
263 selected.copy_to_slice(&mut out[i..i + LANES]);
264
265 let final_mask = match op {
267 ArithmeticOperator::Divide | ArithmeticOperator::Remainder => {
268 m_src & !div_zero
270 }
271 _ => m_src,
272 };
273 let mbits = final_mask.to_bitmask();
274 for l in 0..LANES {
275 let idx = i + l;
276 if idx < n {
277 unsafe { out_mask.set_unchecked(idx, ((mbits >> l) & 1) == 1) };
278 }
279 }
280 i += LANES;
281 }
282
283 for j in i..n {
285 let valid = unsafe { mask.get_unchecked(j) };
286 if valid {
287 let (result, final_valid) = match op {
288 ArithmeticOperator::Add => (lhs[j].wrapping_add(&rhs[j]), true),
289 ArithmeticOperator::Subtract => (lhs[j].wrapping_sub(&rhs[j]), true),
290 ArithmeticOperator::Multiply => (lhs[j].wrapping_mul(&rhs[j]), true),
291 ArithmeticOperator::Divide => {
292 if rhs[j] == T::zero() {
293 (T::zero(), false) } else {
295 (lhs[j] / rhs[j], true)
296 }
297 }
298 ArithmeticOperator::Remainder => {
299 if rhs[j] == T::zero() {
300 (T::zero(), false) } else {
302 (lhs[j] % rhs[j], true)
303 }
304 }
305 ArithmeticOperator::Power => (lhs[j].pow(rhs[j].to_u32().unwrap_or(0)), true),
306 };
307 out[j] = result;
308 unsafe { out_mask.set_unchecked(j, final_valid) };
309 } else {
310 out[j] = T::zero();
311 unsafe { out_mask.set_unchecked(j, false) };
312 }
313 }
314}
315
316#[inline(always)]
320pub fn float_masked_body_f32_simd<const LANES: usize>(
321 op: ArithmeticOperator,
322 lhs: &[f32],
323 rhs: &[f32],
324 mask: &Bitmask,
325 out: &mut [f32],
326 out_mask: &mut Bitmask,
327) where
328 LaneCount<LANES>: SupportedLaneCount,
329{
330 type M = <f32 as SimdElement>::Mask;
331
332 let n = lhs.len();
333 let mut i = 0;
334 let dense = all_true_mask_simd(mask);
335
336 while i + LANES <= n {
337 let a = Simd::<f32, LANES>::from_slice(&lhs[i..i + LANES]);
338 let b = Simd::<f32, LANES>::from_slice(&rhs[i..i + LANES]);
339 let m: Mask<M, LANES> = if dense {
340 Mask::splat(true)
341 } else {
342 simd_mask::<M, LANES>(mask, i, n)
343 };
344
345 let res = match op {
346 ArithmeticOperator::Add => a + b,
347 ArithmeticOperator::Subtract => a - b,
348 ArithmeticOperator::Multiply => a * b,
349 ArithmeticOperator::Divide => a / b,
350 ArithmeticOperator::Remainder => a % b,
351 ArithmeticOperator::Power => (b * a.ln()).exp(),
352 };
353
354 let selected = m.select(res, Simd::<f32, LANES>::splat(0.0));
355 selected.copy_to_slice(&mut out[i..i + LANES]);
356
357 let mbits = m.to_bitmask();
358 for l in 0..LANES {
359 let idx = i + l;
360 if idx < n {
361 unsafe { out_mask.set_unchecked(idx, ((mbits >> l) & 1) == 1) };
362 }
363 }
364 i += LANES;
365 }
366
367 for j in i..n {
369 let valid = dense || unsafe { mask.get_unchecked(j) };
370 if valid {
371 out[j] = match op {
372 ArithmeticOperator::Add => lhs[j] + rhs[j],
373 ArithmeticOperator::Subtract => lhs[j] - rhs[j],
374 ArithmeticOperator::Multiply => lhs[j] * rhs[j],
375 ArithmeticOperator::Divide => lhs[j] / rhs[j],
376 ArithmeticOperator::Remainder => lhs[j] % rhs[j],
377 ArithmeticOperator::Power => (rhs[j] * lhs[j].ln()).exp(),
378 };
379 unsafe { out_mask.set_unchecked(j, true) };
380 } else {
381 out[j] = 0.0;
382 unsafe { out_mask.set_unchecked(j, false) };
383 }
384 }
385}
386
387#[inline(always)]
391pub fn float_masked_body_f64_simd<const LANES: usize>(
392 op: ArithmeticOperator,
393 lhs: &[f64],
394 rhs: &[f64],
395 mask: &Bitmask,
396 out: &mut [f64],
397 out_mask: &mut Bitmask,
398) where
399 LaneCount<LANES>: SupportedLaneCount,
400{
401 type M = <f64 as SimdElement>::Mask;
402
403 let n = lhs.len();
404 let dense = all_true_mask_simd(mask);
405
406 if dense {
407 float_dense_body_f64_simd::<LANES>(op, lhs, rhs, out);
409 out_mask.fill(true);
410 return;
411 }
412
413 let mut i = 0;
414 while i + LANES <= n {
415 let a = Simd::<f64, LANES>::from_slice(&lhs[i..i + LANES]);
416 let b = Simd::<f64, LANES>::from_slice(&rhs[i..i + LANES]);
417 let m: Mask<M, LANES> = simd_mask::<M, LANES>(mask, i, n);
418
419 let res = match op {
420 ArithmeticOperator::Add => a + b,
421 ArithmeticOperator::Subtract => a - b,
422 ArithmeticOperator::Multiply => a * b,
423 ArithmeticOperator::Divide => a / b,
424 ArithmeticOperator::Remainder => a % b,
425 ArithmeticOperator::Power => (b * a.ln()).exp(),
426 };
427
428 let selected = m.select(res, Simd::<f64, LANES>::splat(0.0));
429 selected.copy_to_slice(&mut out[i..i + LANES]);
430
431 let mbits = m.to_bitmask();
432 for l in 0..LANES {
433 let idx = i + l;
434 if idx < n {
435 unsafe { out_mask.set_unchecked(idx, ((mbits >> l) & 1) == 1) };
436 }
437 }
438 i += LANES;
439 }
440
441 for j in i..n {
443 let valid = unsafe { mask.get_unchecked(j) };
444 if valid {
445 out[j] = match op {
446 ArithmeticOperator::Add => lhs[j] + rhs[j],
447 ArithmeticOperator::Subtract => lhs[j] - rhs[j],
448 ArithmeticOperator::Multiply => lhs[j] * rhs[j],
449 ArithmeticOperator::Divide => lhs[j] / rhs[j],
450 ArithmeticOperator::Remainder => lhs[j] % rhs[j],
451 ArithmeticOperator::Power => (rhs[j] * lhs[j].ln()).exp(),
452 };
453 unsafe { out_mask.set_unchecked(j, true) };
454 } else {
455 out[j] = 0.0;
456 unsafe { out_mask.set_unchecked(j, false) };
457 }
458 }
459}
460
461#[inline(always)]
465pub fn float_dense_body_f32_simd<const LANES: usize>(
466 op: ArithmeticOperator,
467 lhs: &[f32],
468 rhs: &[f32],
469 out: &mut [f32],
470) where
471 LaneCount<LANES>: SupportedLaneCount,
472{
473 let n = lhs.len();
474 let mut i = 0;
475 while i + LANES <= n {
476 let a = Simd::<f32, LANES>::from_slice(&lhs[i..i + LANES]);
477 let b = Simd::<f32, LANES>::from_slice(&rhs[i..i + LANES]);
478 let res = match op {
479 ArithmeticOperator::Add => a + b,
480 ArithmeticOperator::Subtract => a - b,
481 ArithmeticOperator::Multiply => a * b,
482 ArithmeticOperator::Divide => a / b,
483 ArithmeticOperator::Remainder => a % b,
484 ArithmeticOperator::Power => (b * a.ln()).exp(),
485 };
486 res.copy_to_slice(&mut out[i..i + LANES]);
487 i += LANES;
488 }
489
490 for j in i..n {
492 out[j] = match op {
493 ArithmeticOperator::Add => lhs[j] + rhs[j],
494 ArithmeticOperator::Subtract => lhs[j] - rhs[j],
495 ArithmeticOperator::Multiply => lhs[j] * rhs[j],
496 ArithmeticOperator::Divide => lhs[j] / rhs[j],
497 ArithmeticOperator::Remainder => lhs[j] % rhs[j],
498 ArithmeticOperator::Power => (rhs[j] * lhs[j].ln()).exp(),
499 };
500 }
501}
502
503#[inline(always)]
507pub fn float_dense_body_f64_simd<const LANES: usize>(
508 op: ArithmeticOperator,
509 lhs: &[f64],
510 rhs: &[f64],
511 out: &mut [f64],
512) where
513 LaneCount<LANES>: SupportedLaneCount,
514{
515 let n = lhs.len();
516 let mut i = 0;
517 while i + LANES <= n {
518 let a = Simd::<f64, LANES>::from_slice(&lhs[i..i + LANES]);
519 let b = Simd::<f64, LANES>::from_slice(&rhs[i..i + LANES]);
520 let res = match op {
521 ArithmeticOperator::Add => a + b,
522 ArithmeticOperator::Subtract => a - b,
523 ArithmeticOperator::Multiply => a * b,
524 ArithmeticOperator::Divide => a / b,
525 ArithmeticOperator::Remainder => a % b,
526 ArithmeticOperator::Power => (b * a.ln()).exp(),
527 };
528 res.copy_to_slice(&mut out[i..i + LANES]);
529 i += LANES;
530 }
531
532 for j in i..n {
534 out[j] = match op {
535 ArithmeticOperator::Add => lhs[j] + rhs[j],
536 ArithmeticOperator::Subtract => lhs[j] - rhs[j],
537 ArithmeticOperator::Multiply => lhs[j] * rhs[j],
538 ArithmeticOperator::Divide => lhs[j] / rhs[j],
539 ArithmeticOperator::Remainder => lhs[j] % rhs[j],
540 ArithmeticOperator::Power => (rhs[j] * lhs[j].ln()).exp(),
541 };
542 }
543}
544
545#[inline(always)]
548pub fn fma_masked_body_f32_simd<const LANES: usize>(
549 lhs: &[f32],
550 rhs: &[f32],
551 acc: &[f32],
552 mask: &Bitmask,
553 out: &mut [f32],
554 out_mask: &mut minarrow::Bitmask,
555) where
556 LaneCount<LANES>: SupportedLaneCount,
557{
558 use core::simd::{Mask, Simd};
559
560 let n = lhs.len();
561 let mut i = 0;
562 let dense = all_true_mask_simd(mask);
563
564 if dense {
565 fma_dense_body_f32_simd::<LANES>(lhs, rhs, acc, out);
566 out_mask.fill(true);
567 return;
568 }
569
570 while i + LANES <= n {
571 let a = Simd::<f32, LANES>::from_slice(&lhs[i..i + LANES]);
572 let b = Simd::<f32, LANES>::from_slice(&rhs[i..i + LANES]);
573 let c = Simd::<f32, LANES>::from_slice(&acc[i..i + LANES]);
574 let m: Mask<i32, LANES> = simd_mask::<i32, LANES>(mask, i, n);
575
576 let res = a.mul_add(b, c);
577
578 let selected = m.select(res, Simd::<f32, LANES>::splat(0.0));
579 selected.copy_to_slice(&mut out[i..i + LANES]);
580
581 let mbits = m.to_bitmask();
582 for l in 0..LANES {
583 let idx = i + l;
584 if idx < n {
585 unsafe { out_mask.set_unchecked(idx, ((mbits >> l) & 1) == 1) };
586 }
587 }
588 i += LANES;
589 }
590
591 for j in i..n {
594 let valid = unsafe { mask.get_unchecked(j) };
595 if valid {
596 out[j] = lhs[j].mul_add(rhs[j], acc[j]);
597 unsafe { out_mask.set_unchecked(j, true) };
598 } else {
599 out[j] = 0.0;
600 unsafe { out_mask.set_unchecked(j, false) };
601 }
602 }
603}
604
605#[inline(always)]
608pub fn fma_masked_body_f64_simd<const LANES: usize>(
609 lhs: &[f64],
610 rhs: &[f64],
611 acc: &[f64],
612 mask: &Bitmask,
613 out: &mut [f64],
614 out_mask: &mut minarrow::Bitmask,
615) where
616 LaneCount<LANES>: SupportedLaneCount,
617{
618 use core::simd::{Mask, Simd};
619
620 let n = lhs.len();
621 let mut i = 0;
622 let dense = all_true_mask_simd(mask);
623
624 if dense {
625 fma_dense_body_f64_simd::<LANES>(lhs, rhs, acc, out);
627 out_mask.fill(true);
628 return;
629 }
630
631 while i + LANES <= n {
632 let a = Simd::<f64, LANES>::from_slice(&lhs[i..i + LANES]);
633 let b = Simd::<f64, LANES>::from_slice(&rhs[i..i + LANES]);
634 let c = Simd::<f64, LANES>::from_slice(&acc[i..i + LANES]);
635 let m: Mask<i64, LANES> = simd_mask::<i64, LANES>(mask, i, n);
636
637 let res = a.mul_add(b, c);
638
639 let selected = m.select(res, Simd::<f64, LANES>::splat(0.0));
640 selected.copy_to_slice(&mut out[i..i + LANES]);
641
642 let mbits = m.to_bitmask();
643 for l in 0..LANES {
644 let idx = i + l;
645 if idx < n {
646 unsafe { out_mask.set_unchecked(idx, ((mbits >> l) & 1) == 1) };
647 }
648 }
649 i += LANES;
650 }
651
652 for j in i..n {
655 let valid = unsafe { mask.get_unchecked(j) };
656 if valid {
657 out[j] = lhs[j].mul_add(rhs[j], acc[j]);
658 unsafe { out_mask.set_unchecked(j, true) };
659 } else {
660 out[j] = 0.0;
661 unsafe { out_mask.set_unchecked(j, false) };
662 }
663 }
664}
665
666#[inline(always)]
669pub fn fma_dense_body_f32_simd<const LANES: usize>(
670 lhs: &[f32],
671 rhs: &[f32],
672 acc: &[f32],
673 out: &mut [f32],
674) where
675 LaneCount<LANES>: SupportedLaneCount,
676{
677 use core::simd::Simd;
678
679 let n = lhs.len();
680 let mut i = 0;
681
682 while i + LANES <= n {
683 let a = Simd::<f32, LANES>::from_slice(&lhs[i..i + LANES]);
684 let b = Simd::<f32, LANES>::from_slice(&rhs[i..i + LANES]);
685 let c = Simd::<f32, LANES>::from_slice(&acc[i..i + LANES]);
686 let res = a.mul_add(b, c);
687 res.copy_to_slice(&mut out[i..i + LANES]);
688 i += LANES;
689 }
690
691 for j in i..n {
692 out[j] = lhs[j].mul_add(rhs[j], acc[j]);
693 }
694}
695
696#[inline(always)]
699pub fn fma_dense_body_f64_simd<const LANES: usize>(
700 lhs: &[f64],
701 rhs: &[f64],
702 acc: &[f64],
703 out: &mut [f64],
704) where
705 LaneCount<LANES>: SupportedLaneCount,
706{
707 use core::simd::Simd;
708
709 let n = lhs.len();
710 let mut i = 0;
711
712 while i + LANES <= n {
713 let a = Simd::<f64, LANES>::from_slice(&lhs[i..i + LANES]);
714 let b = Simd::<f64, LANES>::from_slice(&rhs[i..i + LANES]);
715 let c = Simd::<f64, LANES>::from_slice(&acc[i..i + LANES]);
716 let res = a.mul_add(b, c);
717 res.copy_to_slice(&mut out[i..i + LANES]);
718 i += LANES;
719 }
720
721 for j in i..n {
723 out[j] = lhs[j].mul_add(rhs[j], acc[j]);
724 }
725}