1#![no_std]
2#![forbid(unsafe_code)]
3
4use core::{fmt, mem, ops};
7
8#[repr(C)]
9#[derive(Clone, Copy)]
10pub struct Simd<TArray>(TArray, <Self as Sealed>::Align) where Self: Vector;
11
12pub trait Vector: Copy + Sealed {
13 type Element: Copy;
14 type MaskVector: Vector;
15}
16
17fn simd<TArray>(array: TArray) -> Simd<TArray>
18 where Simd<TArray>: Vector
19{
20 Simd(array, Default::default())
21}
22
23impl<TArray> ops::Deref for Simd<TArray>
24 where Self: Vector
25{
26 type Target = TArray;
27 #[inline]
28 fn deref(&self) -> &TArray {
29 &self.0
30 }
31}
32impl<TArray> ops::DerefMut for Simd<TArray>
33 where Self: Vector
34{
35 #[inline]
36 fn deref_mut(&mut self) -> &mut TArray {
37 &mut self.0
38 }
39}
40
// Expands a table of vector descriptions into concrete vector types.
//
// Input grammar (every level may repeat):
//
//     <AlignMarker> (
//         @<mask vector alias> <unsigned vector alias>
//             <alias> <element type> <lane count>
//             ...
//     )
//
// For each `<alias> <element> <lanes>` row the macro emits the type alias
// plus the `Sealed`, `Vector` and `SimdImpl` impls for `Simd<[element; lanes]>`
// and a conversion back into the plain array. For each `@mask unsigned`
// header it additionally emits `From<mask vector> for unsigned vector`.
macro_rules! define_vector_type {
    (
        $(
            $align:ident (
                $(
                    @ $mask_vec:ident $uint_vec:ident
                    $( $alias:ident $elem:ident $lanes:literal )+
                )+
            )
        )+
    ) => {
        $($(
            $(
                #[allow(non_camel_case_types)]
                pub type $alias = Simd<[$elem; $lanes]>;

                impl Sealed for Simd<[$elem; $lanes]> {
                    type Align = $align;
                }

                impl Vector for Simd<[$elem; $lanes]> {
                    type Element = $elem;
                    type MaskVector = $mask_vec;
                }

                impl SimdImpl for Simd<[$elem; $lanes]> {
                    type Array = [Self::Element; $lanes];
                    type Mask = <Self::MaskVector as Vector>::Element;

                    fn as_slice(&self) -> &[Self::Element] {
                        &self.0
                    }

                    #[inline]
                    fn repeat(value: Self::Element) -> Self {
                        simd([value; $lanes])
                    }

                    #[inline]
                    fn map(self, op: impl Fn($elem) -> $elem) -> Self {
                        simd(array_utils::map(self.0, op))
                    }

                    #[inline]
                    fn zip(self, other: Self, op: impl Fn($elem, $elem) -> $elem) -> Self {
                        simd(array_utils::zip(self.0, other.0, op))
                    }

                    #[inline]
                    fn zip_mask(
                        self,
                        other: Self,
                        op: impl Fn($elem, $elem) -> Self::Mask,
                    ) -> Self::MaskVector {
                        simd(array_utils::zip(self.0, other.0, op))
                    }
                }

                // Unwrap a vector back into its plain lane array.
                impl From<Simd<[$elem; $lanes]>> for [$elem; $lanes] {
                    #[inline]
                    fn from(vector: Simd<[$elem; $lanes]>) -> Self {
                        vector.0
                    }
                }
            )+

            // Convert each mask lane into the unsigned element type through
            // the scalar `Into` conversion.
            impl From<$mask_vec> for $uint_vec {
                #[inline]
                fn from(mask: $mask_vec) -> Self {
                    simd(array_utils::map(mask.0, Into::into))
                }
            }
        )+)+
    };
}
// Instantiates every vector type provided by this crate.
//
// Each outer group names the alignment marker used by the vectors listed in
// it. Inside a group, an `@mask unsigned` header is followed by
// `alias element lanes` rows: the header's mask vector becomes the
// `MaskVector` of every row beneath it, and a `From<mask>` conversion into
// the unsigned vector is generated once per header.
define_vector_type!(
    // 8-byte vectors.
    Align8 (
        @m8x8 u8x8
            i8x8 i8 8
            u8x8 u8 8
            m8x8 m8 8

        @m16x4 u16x4
            i16x4 i16 4
            u16x4 u16 4
            m16x4 m16 4

        @m32x2 u32x2
            i32x2 i32 2
            u32x2 u32 2
            m32x2 m32 2
            f32x2 f32 2

        @m64x1 u64x1
            i64x1 i64 1
            u64x1 u64 1
            m64x1 m64 1
            f64x1 f64 1
    )
    // 16-byte vectors.
    Align16 (
        @m8x16 u8x16
            i8x16 i8 16
            u8x16 u8 16
            m8x16 m8 16

        @m16x8 u16x8
            i16x8 i16 8
            u16x8 u16 8
            m16x8 m16 8

        @m32x4 u32x4
            i32x4 i32 4
            u32x4 u32 4
            m32x4 m32 4
            f32x4 f32 4

        @m64x2 u64x2
            i64x2 i64 2
            u64x2 u64 2
            m64x2 m64 2
            f64x2 f64 2
    )
    // 32-byte vectors.
    Align32 (
        @m8x32 u8x32
            i8x32 i8 32
            u8x32 u8 32
            m8x32 m8 32

        @m16x16 u16x16
            i16x16 i16 16
            u16x16 u16 16
            m16x16 m16 16

        @m32x8 u32x8
            i32x8 i32 8
            u32x8 u32 8
            m32x8 m32 8
            f32x8 f32 8

        @m64x4 u64x4
            i64x4 i64 4
            u64x4 u64 4
            m64x4 m64 4
            f64x4 f64 4
    )
    // 64-byte vectors.
    Align64 (
        @m8x64 u8x64
            i8x64 i8 64
            u8x64 u8 64
            m8x64 m8 64

        @m16x32 u16x32
            i16x32 i16 32
            u16x32 u16 32
            m16x32 m16 32

        @m32x16 u32x16
            i32x16 i32 16
            u32x16 u32 16
            m32x16 m32 16
            f32x16 f32 16

        @m64x8 u64x8
            i64x8 i64 8
            u64x8 u64 8
            m64x8 m64 8
            f64x8 f64 8
    )
);
192
193impl<TArray> From<TArray> for Simd<TArray>
194 where Self: Vector
195{
196 #[inline]
197 fn from(array: TArray) -> Self {
198 simd(array)
199 }
200}
201
202impl<TArray> Simd<TArray>
203 where Self: SimdImpl
204{
205 #[inline]
206 pub fn splat(value: <Self as Vector>::Element) -> Self {
207 Self::repeat(value)
208 }
209}
210
211impl<TArray> Default for Simd<TArray>
212where
213 Self: SimdImpl,
214 <Self as Vector>::Element: Default
215{
216 #[inline]
217 fn default() -> Self {
218 Self::splat(Default::default())
219 }
220}
221
222impl<TArray> fmt::Debug for Simd<TArray>
223where
224 Self: SimdImpl,
225 <Self as Vector>::Element: fmt::Debug
226{
227 #[inline]
228 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
229 fmt::Debug::fmt(self.as_slice(), f)
230 }
231}
232
233impl<TArray> Simd<TArray>
234where
235 Self: SimdImpl,
236 <Self as Vector>::Element: PartialOrd
237{
238 #[inline]
239 pub fn eq(self, other: Self) -> <Self as Vector>::MaskVector {
240 self.zip_mask(other, |a, b| (a == b).into())
241 }
242
243 #[inline]
244 pub fn ne(self, other: Self) -> <Self as Vector>::MaskVector {
245 self.zip_mask(other, |a, b| (a != b).into())
246 }
247
248 #[inline]
249 pub fn lt(self, other: Self) -> <Self as Vector>::MaskVector {
250 self.zip_mask(other, |a, b| (a < b).into())
251 }
252
253 #[inline]
254 pub fn gt(self, other: Self) -> <Self as Vector>::MaskVector {
255 self.zip_mask(other, |a, b| (a > b).into())
256 }
257
258 #[inline]
259 pub fn le(self, other: Self) -> <Self as Vector>::MaskVector {
260 self.zip_mask(other, |a, b| (a <= b).into())
261 }
262
263 #[inline]
264 pub fn ge(self, other: Self) -> <Self as Vector>::MaskVector {
265 self.zip_mask(other, |a, b| (a >= b).into())
266 }
267}
268
269impl<TArray> Simd<TArray>
270where
271 Self: SimdImpl,
272 <Self as Vector>::Element: Ord
273{
274 #[inline]
275 pub fn min(self, other: Self) -> Self {
276 self.zip(other, Ord::min)
277 }
278
279 #[inline]
280 pub fn max(self, other: Self) -> Self {
281 self.zip(other, Ord::max)
282 }
283}
284
285impl<TArray> Simd<TArray>
286where
287 Self: SimdImpl,
288 <Self as Vector>::Element: Integer
289{
290 #[inline]
291 pub fn wrapping_add(self, other: Self) -> Self {
292 self.zip(other, Integer::wrapping_add)
293 }
294
295 #[inline]
296 pub fn wrapping_sub(self, other: Self) -> Self {
297 self.zip(other, Integer::wrapping_sub)
298 }
299
300 #[inline]
301 pub fn wrapping_mul(self, other: Self) -> Self {
302 self.zip(other, Integer::wrapping_mul)
303 }
304
305 #[inline]
306 pub fn high_mul(self, other: Self) -> Self {
307 self.zip(other, Integer::high_mul)
308 }
309
310 #[inline]
311 pub fn saturating_add(self, other: Self) -> Self {
312 self.zip(other, Integer::saturating_add)
313 }
314
315 #[inline]
316 pub fn saturating_sub(self, other: Self) -> Self {
317 self.zip(other, Integer::saturating_sub)
318 }
319
320 #[inline]
321 pub fn count_ones(self) -> Self {
322 self.map(Integer::count_ones)
323 }
324
325 #[inline]
326 pub fn count_zeros(self) -> Self {
327 self.map(Integer::count_zeros)
328 }
329}
330
331impl<TArray> Simd<TArray>
332where
333 Self: SimdImpl,
334 <Self as Vector>::Element: SignedInteger
335{
336 #[inline]
337 pub fn wrapping_abs(self) -> Self {
338 self.map(SignedInteger::wrapping_abs)
339 }
340}
341
342impl<TArray> Simd<TArray>
343where
344 Self: SimdImpl,
345 <Self as Vector>::Element: FloatingPoint
346{
347 #[inline]
348 pub fn recip(self) -> Self {
349 self.map(FloatingPoint::recip)
350 }
351
352 #[inline]
353 pub fn to_degrees(self) -> Self {
354 self.map(FloatingPoint::to_degrees)
355 }
356
357 #[inline]
358 pub fn to_radians(self) -> Self {
359 self.map(FloatingPoint::to_radians)
360 }
361
362 #[inline]
363 pub fn min_naive(self, other: Self) -> Self {
364 self.zip(other, FloatingPoint::min_naive)
365 }
366
367 #[inline]
368 pub fn max_naive(self, other: Self) -> Self {
369 self.zip(other, FloatingPoint::max_naive)
370 }
371}
372
// Implements binary `core::ops` traits for `Simd` by zipping the two operands
// with the element type's own operator.
//
// Each entry is `Trait method [where Guard],`. The optional guard adds an
// extra bound on the element type, restricting which vectors get the operator.
macro_rules! forward_ops_as_zip {
    ($($op:ident $method:ident $(where $guard:ident)? ,)+) => {$(
        impl<TArray> ops::$op for Simd<TArray>
        where
            Self: SimdImpl,
            <Self as Vector>::Element: ops::$op<Output = <Self as Vector>::Element>,
            $( <Self as Vector>::Element: $guard, )?
        {
            type Output = Self;

            #[inline]
            fn $method(self, other: Self) -> Self {
                SimdImpl::zip(self, other, ops::$op::$method)
            }
        }
    )+};
}
// Implements unary `core::ops` traits for `Simd` by mapping the element
// type's own operator over the operand.
//
// Each entry is `Trait method [where Guard],`. The optional guard adds an
// extra bound on the element type, restricting which vectors get the operator.
macro_rules! forward_ops_as_map {
    ($($op:ident $method:ident $(where $guard:ident)? ,)+) => {$(
        impl<TArray> ops::$op for Simd<TArray>
        where
            Self: SimdImpl,
            <Self as Vector>::Element: ops::$op<Output = <Self as Vector>::Element>,
            $( <Self as Vector>::Element: $guard, )?
        {
            type Output = Self;

            #[inline]
            fn $method(self) -> Self {
                SimdImpl::map(self, ops::$op::$method)
            }
        }
    )+};
}
// Binary operators. The bitwise ones are forwarded for any element type that
// implements them; the arithmetic ones are additionally limited to
// `FloatingPoint` elements.
// NOTE(review): integer vectors presumably rely on the explicit
// `wrapping_*` / `saturating_*` methods instead — confirm intent.
forward_ops_as_zip!(
    BitAnd bitand,
    BitOr bitor,
    BitXor bitxor,

    Add add where FloatingPoint,
    Sub sub where FloatingPoint,
    Mul mul where FloatingPoint,
    Div div where FloatingPoint,
    Rem rem where FloatingPoint,
);
// Unary operators: `Not` for any element type that implements it, `Neg` only
// for `FloatingPoint` elements.
forward_ops_as_map!(
    Not not,

    Neg neg where FloatingPoint,
);
421
422use internals::*;
423mod internals {
424 pub trait Sealed {
425 type Align: Copy + Default;
426 }
427
428 pub trait SimdImpl: super::Vector {
429 fn as_slice(&self) -> &[Self::Element];
430
431 type Array;
432 fn repeat(value: Self::Element) -> Self;
433 fn map(self, f: impl Fn(Self::Element) -> Self::Element) -> Self;
434 fn zip(self, other: Self, f: impl Fn(Self::Element, Self::Element) -> Self::Element) -> Self;
435
436 type Mask: From<bool> + Into<bool>;
437 fn zip_mask(self, other: Self, f: impl Fn(Self::Element, Self::Element) -> Self::Mask) -> Self::MaskVector;
438 }
439
440 macro_rules! define_align_types {
441 ($($t:ident $n:literal)+) => {$(
442 #[repr(align($n))]
443 #[derive(Clone, Copy, Default)]
444 pub struct $t;
445 )+};
446 }
447 define_align_types!(
448 Align8 8
449 Align16 16
450 Align32 32
451 Align64 64
452 );
453
454 pub trait Integer {
455 fn wrapping_add(self, other: Self) -> Self;
456 fn wrapping_sub(self, other: Self) -> Self;
457 fn saturating_add(self, other: Self) -> Self;
458 fn saturating_sub(self, other: Self) -> Self;
459 fn wrapping_mul(self, other: Self) -> Self;
460 fn high_mul(self, other: Self) -> Self;
461 fn count_ones(self) -> Self;
462 fn count_zeros(self) -> Self;
463 }
464 pub trait SignedInteger: Integer {
465 fn wrapping_abs(self) -> Self;
466 }
467 pub trait FloatingPoint {
468 fn recip(self) -> Self;
469 fn to_degrees(self) -> Self;
470 fn to_radians(self) -> Self;
471 fn min_naive(self, other: Self) -> Self;
472 fn max_naive(self, other: Self) -> Self;
473 }
474}
475
476macro_rules! impl_integer {
477 ($($t:ident)+) => {$(
478 impl Integer for $t {
479 #[inline]
480 fn wrapping_add(self, other: Self) -> Self { $t::wrapping_add(self, other) }
481 #[inline]
482 fn wrapping_sub(self, other: Self) -> Self { $t::wrapping_sub(self, other) }
483 #[inline]
484 fn wrapping_mul(self, other: Self) -> Self { $t::wrapping_mul(self, other) }
485 #[inline]
486 fn high_mul(self, other: Self) -> Self { <$t as HighMul>::high_mul(self, other) }
487 #[inline]
488 fn saturating_add(self, other: Self) -> Self { $t::saturating_add(self, other) }
489 #[inline]
490 fn saturating_sub(self, other: Self) -> Self { $t::saturating_sub(self, other) }
491 #[inline]
492 fn count_ones(self) -> Self { $t::count_ones(self) as _ }
493 #[inline]
494 fn count_zeros(self) -> Self { $t::count_zeros(self) as _ }
495 }
496 )+};
497}
498impl_integer!(u8 u16 u32 u64);
499macro_rules! impl_signed_integer {
500 ($($t:ident)+) => {$(
501 impl_integer!($t);
502 impl SignedInteger for $t {
503 #[inline]
504 fn wrapping_abs(self) -> Self { $t::wrapping_abs(self) }
505 }
506 )+};
507}
508impl_signed_integer!(i8 i16 i32 i64);
509macro_rules! impl_floating_point {
510 ($($t:ident)+) => {$(
511 impl FloatingPoint for $t {
512 #[inline]
513 fn recip(self) -> Self { $t::recip(self) }
514 #[inline]
515 fn to_degrees(self) -> Self { $t::to_degrees(self) }
516 #[inline]
517 fn to_radians(self) -> Self { $t::to_radians(self) }
518 #[inline]
519 fn min_naive(self, other: Self) -> Self {
520 if self < other { self } else { other }
523 }
524 #[inline]
525 fn max_naive(self, other: Self) -> Self {
526 if self > other { self } else { other }
529 }
530 }
531 )+};
532}
533impl_floating_point!(f32 f64);
534
535macro_rules! define_mask_types {
536 ($($t:ident $p:ident)+) => {$(
537 impl From<bool> for $t {
538 #[inline]
539 fn from(b: bool) -> Self {
540 if b { $t::True } else { $t::False }
541 }
542 }
543 impl From<$t> for bool {
544 #[inline]
545 fn from(m: $t) -> bool {
546 match m {
547 $t::False => false,
548 $t::True => true,
549 }
550 }
551 }
552 impl From<$t> for $p {
553 #[inline]
554 fn from(m: $t) -> $p {
555 m as $p
556 }
557 }
558 impl Default for $t {
559 #[inline]
560 fn default() -> Self { $t::False }
561 }
562 impl array_utils::Zero for $t {
563 const ZERO: Self = $t::False;
564 }
565 #[repr($p)]
566 #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
567 #[allow(non_camel_case_types)]
568 pub enum $t {
569 False = (0 as $p),
570 True = !(0 as $p),
571 }
572 )+};
573}
574define_mask_types!(
575 m8 u8
576 m16 u16
577 m32 u32
578 m64 u64
579);
580
/// Multiplication that keeps only the upper half of the double-width product.
trait HighMul {
    fn high_mul(self, other: Self) -> Self;
}

// Implements `HighMul` for each `narrow wide` pair, where `wide` is the
// integer type with twice the bit width of `narrow`.
macro_rules! impl_high_mul {
    ($($narrow:ident $wide:ident)+) => {$(
        impl HighMul for $narrow {
            #[inline]
            fn high_mul(self, rhs: Self) -> Self {
                // Bit width of the narrow type; shifting the wide product by
                // this amount leaves its upper half.
                let shift = mem::size_of::<$narrow>() * 8;
                // Widening is lossless, so the product cannot overflow.
                let product = <$wide>::from(self) * <$wide>::from(rhs);
                (product >> shift) as $narrow
            }
        }
    )+};
}
impl_high_mul!(
    u8 u16
    u16 u32
    u32 u64
    u64 u128
    i8 i16
    i16 i32
    i32 i64
    i64 i128
);
606
607mod array_utils;
608
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: construction, slicing through `Deref`, and wrapping
    /// arithmetic on a four-lane `i32` vector.
    #[test]
    fn it_works() {
        let ones = i32x4::splat(1);
        assert_eq!(ones[..], [1, 1, 1, 1]);

        let a = i32x4::from([1, 2, 3, 4]);
        let b = i32x4::from([45, 56, 78, 89]);
        let c = b.wrapping_sub(a);
        assert_eq!(c[..], [44, 54, 75, 85]);
        let d = c.wrapping_add(Simd::splat(10));
        assert_eq!(d[..], [54, 64, 85, 95]);
    }

    /// `Default` must exist for every width (including 64 lanes) and must
    /// zero every lane. Slices are compared on both sides so the check does
    /// not depend on array-comparison impls for large arrays.
    #[test]
    fn defaults() {
        assert_eq!(i8x8::default()[..], [0i8; 8][..]);
        assert_eq!(i8x16::default()[..], [0i8; 16][..]);
        assert_eq!(i8x32::default()[..], [0i8; 32][..]);
        assert_eq!(i8x64::default()[..], [0i8; 64][..]);
    }

    /// The alignment marker must raise the alignment to the vector width
    /// without adding any padding.
    #[test]
    fn layout_follows_alignment_marker() {
        assert_eq!(mem::size_of::<i8x8>(), 8);
        assert_eq!(mem::align_of::<i8x8>(), 8);
        assert_eq!(mem::size_of::<i8x16>(), 16);
        assert_eq!(mem::align_of::<i8x16>(), 16);
        assert_eq!(mem::size_of::<i8x32>(), 32);
        assert_eq!(mem::align_of::<i8x32>(), 32);
        assert_eq!(mem::size_of::<i8x64>(), 64);
        assert_eq!(mem::align_of::<i8x64>(), 64);
    }

    /// The derived ordering on masks places `False` below `True`.
    #[test]
    fn mask_comparison() {
        assert!(m16::False < m16::True);
        assert!(m16::False <= m16::True);
    }
}