Skip to main content

zenjxl_decoder_simd/
scalar.rs

// Copyright (c) the JPEG XL Project Authors. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
5
6use std::num::Wrapping;
7
8use crate::{U32SimdVec, f16, impl_f32_array_interface};
9
10use super::{F32SimdVec, I32SimdVec, SimdDescriptor, SimdMask, U8SimdVec, U16SimdVec};
11
/// Zero-sized descriptor that selects the scalar backend.
///
/// It carries no state; its `SimdDescriptor` impl maps every vector type to a
/// single plain value (`f32`, `Wrapping<i32>`, `u8`, ...).
#[derive(Debug, Copy, Clone)]
pub struct ScalarDescriptor;
14
15impl ScalarDescriptor {
16    #[inline]
17    pub fn from_token(_token: archmage::ScalarToken) -> Self {
18        Self
19    }
20}
21
22impl SimdDescriptor for ScalarDescriptor {
23    type F32Vec = f32;
24    type I32Vec = Wrapping<i32>;
25    type U32Vec = Wrapping<u32>;
26    type U8Vec = u8;
27    type U16Vec = u16;
28    type Mask = bool;
29    type Bf16Table8 = [f32; 8];
30
31    type Descriptor256 = Self;
32    type Descriptor128 = Self;
33
34    #[inline]
35    fn maybe_downgrade_256bit(self) -> Self::Descriptor256 {
36        self
37    }
38
39    #[inline]
40    fn maybe_downgrade_128bit(self) -> Self::Descriptor128 {
41        self
42    }
43
44    #[inline]
45    fn new() -> Option<Self> {
46        Some(Self)
47    }
48
49    fn call<R>(self, f: impl FnOnce(Self) -> R) -> R {
50        // No special features needed for scalar implementation
51        f(self)
52    }
53}
54
55impl F32SimdVec for f32 {
56    type Descriptor = ScalarDescriptor;
57
58    const LEN: usize = 1;
59
60    #[inline(always)]
61    fn load(_d: Self::Descriptor, mem: &[f32]) -> Self {
62        mem[0]
63    }
64
65    #[inline(always)]
66    fn store(&self, mem: &mut [f32]) {
67        mem[0] = *self;
68    }
69
70    #[inline(always)]
71    fn store_interleaved_2(a: Self, b: Self, dest: &mut [f32]) {
72        dest[0] = a;
73        dest[1] = b;
74    }
75
76    #[inline(always)]
77    fn store_interleaved_3(a: Self, b: Self, c: Self, dest: &mut [f32]) {
78        dest[0] = a;
79        dest[1] = b;
80        dest[2] = c;
81    }
82
83    #[inline(always)]
84    fn store_interleaved_4(a: Self, b: Self, c: Self, d: Self, dest: &mut [f32]) {
85        dest[0] = a;
86        dest[1] = b;
87        dest[2] = c;
88        dest[3] = d;
89    }
90
91    #[inline(always)]
92    fn store_interleaved_8(
93        a: Self,
94        b: Self,
95        c: Self,
96        d: Self,
97        e: Self,
98        f: Self,
99        g: Self,
100        h: Self,
101        dest: &mut [f32],
102    ) {
103        dest[0] = a;
104        dest[1] = b;
105        dest[2] = c;
106        dest[3] = d;
107        dest[4] = e;
108        dest[5] = f;
109        dest[6] = g;
110        dest[7] = h;
111    }
112
113    #[inline(always)]
114    fn load_deinterleaved_2(_d: Self::Descriptor, src: &[f32]) -> (Self, Self) {
115        (src[0], src[1])
116    }
117
118    #[inline(always)]
119    fn load_deinterleaved_3(_d: Self::Descriptor, src: &[f32]) -> (Self, Self, Self) {
120        (src[0], src[1], src[2])
121    }
122
123    #[inline(always)]
124    fn load_deinterleaved_4(_d: Self::Descriptor, src: &[f32]) -> (Self, Self, Self, Self) {
125        (src[0], src[1], src[2], src[3])
126    }
127
128    #[inline(always)]
129    fn mul_add(self, mul: Self, add: Self) -> Self {
130        (self * mul) + add
131    }
132
133    #[inline(always)]
134    fn neg_mul_add(self, mul: Self, add: Self) -> Self {
135        -(self * mul) + add
136    }
137
138    #[inline(always)]
139    fn splat(_d: Self::Descriptor, v: f32) -> Self {
140        v
141    }
142
143    #[inline(always)]
144    fn zero(_d: Self::Descriptor) -> Self {
145        0.0
146    }
147
148    #[inline(always)]
149    fn abs(self) -> Self {
150        self.abs()
151    }
152
153    #[inline(always)]
154    fn floor(self) -> Self {
155        self.floor()
156    }
157
158    #[inline(always)]
159    fn sqrt(self) -> Self {
160        self.sqrt()
161    }
162
163    #[inline(always)]
164    fn neg(self) -> Self {
165        -self
166    }
167
168    #[inline(always)]
169    fn copysign(self, sign: Self) -> Self {
170        self.copysign(sign)
171    }
172
173    #[inline(always)]
174    fn max(self, other: Self) -> Self {
175        self.max(other)
176    }
177
178    #[inline(always)]
179    fn min(self, other: Self) -> Self {
180        self.min(other)
181    }
182
183    #[inline(always)]
184    fn gt(self, other: Self) -> bool {
185        self > other
186    }
187
188    #[inline(always)]
189    fn as_i32(self) -> Wrapping<i32> {
190        Wrapping(self as i32)
191    }
192
193    #[inline(always)]
194    fn bitcast_to_i32(self) -> Wrapping<i32> {
195        Wrapping(self.to_bits() as i32)
196    }
197
198    #[inline(always)]
199    fn prepare_table_bf16_8(_d: Self::Descriptor, table: &[f32; 8]) -> [f32; 8] {
200        // For scalar, just copy the table
201        *table
202    }
203
204    #[inline(always)]
205    fn table_lookup_bf16_8(_d: Self::Descriptor, table: [f32; 8], indices: Wrapping<i32>) -> Self {
206        table[indices.0 as usize]
207    }
208
209    #[inline(always)]
210    fn round_store_u8(self, dest: &mut [u8]) {
211        dest[0] = self.round() as u8;
212    }
213
214    #[inline(always)]
215    fn round_store_u16(self, dest: &mut [u16]) {
216        dest[0] = self.round() as u16;
217    }
218
219    #[inline(always)]
220    fn load_f16_bits(_d: Self::Descriptor, mem: &[u16]) -> Self {
221        f16::from_bits(mem[0]).to_f32()
222    }
223
224    #[inline(always)]
225    fn store_f16_bits(self, dest: &mut [u16]) {
226        dest[0] = f16::from_f32(self).to_bits();
227    }
228
229    impl_f32_array_interface!();
230
231    #[inline(always)]
232    fn transpose_square(_d: Self::Descriptor, _data: &mut [Self::UnderlyingArray], _stride: usize) {
233        // Nothing to do.
234    }
235}
236
237impl I32SimdVec for Wrapping<i32> {
238    type Descriptor = ScalarDescriptor;
239
240    const LEN: usize = 1;
241
242    #[inline(always)]
243    fn splat(_d: Self::Descriptor, v: i32) -> Self {
244        Wrapping(v)
245    }
246
247    #[inline(always)]
248    fn load(_d: Self::Descriptor, mem: &[i32]) -> Self {
249        Wrapping(mem[0])
250    }
251
252    #[inline(always)]
253    fn store(&self, mem: &mut [i32]) {
254        mem[0] = self.0;
255    }
256
257    #[inline(always)]
258    fn abs(self) -> Self {
259        Wrapping(self.0.abs())
260    }
261
262    #[inline(always)]
263    fn as_f32(self) -> f32 {
264        self.0 as f32
265    }
266
267    #[inline(always)]
268    fn bitcast_to_f32(self) -> f32 {
269        f32::from_bits(self.0 as u32)
270    }
271
272    #[inline(always)]
273    fn bitcast_to_u32(self) -> Wrapping<u32> {
274        Wrapping(self.0 as u32)
275    }
276
277    #[inline(always)]
278    fn gt(self, other: Self) -> bool {
279        self.0 > other.0
280    }
281
282    #[inline(always)]
283    fn lt_zero(self) -> bool {
284        self.0 < 0
285    }
286
287    #[inline(always)]
288    fn eq(self, other: Self) -> bool {
289        self.0 == other.0
290    }
291
292    #[inline(always)]
293    fn eq_zero(self) -> bool {
294        self.0 == 0
295    }
296
297    #[inline(always)]
298    fn shl<const AMOUNT_U: u32, const AMOUNT_I: i32>(self) -> Self {
299        Wrapping(self.0 << AMOUNT_U)
300    }
301
302    #[inline(always)]
303    fn shr<const AMOUNT_U: u32, const AMOUNT_I: i32>(self) -> Self {
304        Wrapping(self.0 >> AMOUNT_U)
305    }
306
307    #[inline(always)]
308    fn mul_wide_take_high(self, rhs: Self) -> Self {
309        Wrapping(((self.0 as i64 * rhs.0 as i64) >> 32) as i32)
310    }
311
312    #[inline(always)]
313    fn store_u16(self, dest: &mut [u16]) {
314        dest[0] = self.0 as u16;
315    }
316
317    #[inline(always)]
318    fn store_u8(self, dest: &mut [u8]) {
319        dest[0] = self.0 as u8;
320    }
321}
322
323impl U32SimdVec for Wrapping<u32> {
324    type Descriptor = ScalarDescriptor;
325
326    const LEN: usize = 1;
327
328    #[inline(always)]
329    fn bitcast_to_i32(self) -> Wrapping<i32> {
330        Wrapping(self.0 as i32)
331    }
332
333    #[inline(always)]
334    fn shr<const AMOUNT_U: u32, const AMOUNT_I: i32>(self) -> Self {
335        Wrapping(self.0 >> AMOUNT_U)
336    }
337}
338
339impl U8SimdVec for u8 {
340    type Descriptor = ScalarDescriptor;
341    const LEN: usize = 1;
342
343    #[inline(always)]
344    fn load(_d: Self::Descriptor, mem: &[u8]) -> Self {
345        mem[0]
346    }
347
348    #[inline(always)]
349    fn splat(_d: Self::Descriptor, v: u8) -> Self {
350        v
351    }
352
353    #[inline(always)]
354    fn store(&self, mem: &mut [u8]) {
355        mem[0] = *self;
356    }
357
358    #[inline(always)]
359    fn store_interleaved_2(a: Self, b: Self, dest: &mut [u8]) {
360        dest[0] = a;
361        dest[1] = b;
362    }
363
364    #[inline(always)]
365    fn store_interleaved_3(a: Self, b: Self, c: Self, dest: &mut [u8]) {
366        dest[0] = a;
367        dest[1] = b;
368        dest[2] = c;
369    }
370
371    #[inline(always)]
372    fn store_interleaved_4(a: Self, b: Self, c: Self, d: Self, dest: &mut [u8]) {
373        dest[0] = a;
374        dest[1] = b;
375        dest[2] = c;
376        dest[3] = d;
377    }
378}
379
380impl U16SimdVec for u16 {
381    type Descriptor = ScalarDescriptor;
382    const LEN: usize = 1;
383
384    #[inline(always)]
385    fn load(_d: Self::Descriptor, mem: &[u16]) -> Self {
386        mem[0]
387    }
388
389    #[inline(always)]
390    fn splat(_d: Self::Descriptor, v: u16) -> Self {
391        v
392    }
393
394    #[inline(always)]
395    fn store(&self, mem: &mut [u16]) {
396        mem[0] = *self;
397    }
398
399    #[inline(always)]
400    fn store_interleaved_2(a: Self, b: Self, dest: &mut [u16]) {
401        dest[0] = a;
402        dest[1] = b;
403    }
404
405    #[inline(always)]
406    fn store_interleaved_3(a: Self, b: Self, c: Self, dest: &mut [u16]) {
407        dest[0] = a;
408        dest[1] = b;
409        dest[2] = c;
410    }
411
412    #[inline(always)]
413    fn store_interleaved_4(a: Self, b: Self, c: Self, d: Self, dest: &mut [u16]) {
414        dest[0] = a;
415        dest[1] = b;
416        dest[2] = c;
417        dest[3] = d;
418    }
419}
420
421impl SimdMask for bool {
422    type Descriptor = ScalarDescriptor;
423
424    #[inline(always)]
425    fn if_then_else_f32(self, if_true: f32, if_false: f32) -> f32 {
426        if self { if_true } else { if_false }
427    }
428
429    #[inline(always)]
430    fn if_then_else_i32(self, if_true: Wrapping<i32>, if_false: Wrapping<i32>) -> Wrapping<i32> {
431        if self { if_true } else { if_false }
432    }
433
434    #[inline(always)]
435    fn maskz_i32(self, v: Wrapping<i32>) -> Wrapping<i32> {
436        if self { Wrapping(0) } else { v }
437    }
438
439    #[inline(always)]
440    fn all(self) -> bool {
441        self
442    }
443
444    #[inline(always)]
445    fn andnot(self, rhs: Self) -> Self {
446        (!self) & rhs
447    }
448}
449
/// `simd_function!` for targets other than x86_64, aarch64 and wasm32.
///
/// Expands to two functions that share the user's attributes, visibility,
/// arguments and return type:
///
/// * `$name`: the supplied body, generic over a `$crate::SimdDescriptor` type
///   and taking the descriptor as an extra first argument.
/// * `$dname`: a wrapper without the descriptor argument that calls `$name`
///   with a `$crate::ScalarDescriptor`.
#[cfg(not(any(
    target_arch = "x86_64",
    target_arch = "aarch64",
    target_arch = "wasm32"
)))]
#[macro_export]
macro_rules! simd_function {
    (
        $dname:ident,
        $descr:ident: $descr_ty:ident,
        $(#[$($attr:meta)*])*
        $vis:vis fn $name:ident($($arg:ident: $arg_ty:ty),* $(,)?) $(-> $ret:ty )? $body:block
    ) => {
        // Generic implementation; the body refers to the descriptor as `$descr`.
        #[inline(always)]
        $(#[$($attr)*])*
        $vis fn $name<$descr_ty: $crate::SimdDescriptor>($descr: $descr_ty, $($arg: $arg_ty),*) $(-> $ret)? $body

        // Entry point without a descriptor argument; only the scalar
        // descriptor is used by this definition of the macro.
        $(#[$($attr)*])*
        $vis fn $dname($($arg: $arg_ty),*) $(-> $ret)? {
            use $crate::SimdDescriptor;
            let scalar = $crate::ScalarDescriptor::new().unwrap();
            $name(scalar, $($arg),*)
        }
    };
}
473
/// `test_all_instruction_sets!` for targets other than x86_64, aarch64 and
/// wasm32.
///
/// Generates one `#[test]` function named `<$name>_scalar` that calls `$name`
/// with a `$crate::ScalarDescriptor`.
#[cfg(not(any(
    target_arch = "x86_64",
    target_arch = "aarch64",
    target_arch = "wasm32"
)))]
#[macro_export]
macro_rules! test_all_instruction_sets {
    ($name:ident) => {
        paste::paste! {
            #[test]
            fn [<$name _scalar>]() {
                use $crate::SimdDescriptor;
                let scalar = $crate::ScalarDescriptor::new().unwrap();
                $name(scalar)
            }
        }
    };
}
493
/// `bench_all_instruction_sets!` for targets other than x86_64, aarch64 and
/// wasm32.
///
/// Expands, in the caller's scope, to a `use $crate::SimdDescriptor;` followed
/// by a single call `$name(descriptor, $criterion, "scalar")` with a
/// `$crate::ScalarDescriptor`.
#[cfg(not(any(
    target_arch = "x86_64",
    target_arch = "aarch64",
    target_arch = "wasm32"
)))]
#[macro_export]
macro_rules! bench_all_instruction_sets {
    ($name:ident, $criterion:ident) => {
        // The trait must be in scope for `ScalarDescriptor::new()` to resolve.
        use $crate::SimdDescriptor;
        $name($crate::ScalarDescriptor::new().unwrap(), $criterion, "scalar");
    };
}