scrypt_opt/salsa20/mod.rs

1#![allow(
2    unused,
3    reason = "APIs that allow switching cores in code are not exposed to the public API, yet"
4)]
5
6#[cfg(target_arch = "x86_64")]
7pub(crate) mod x86_64;
8
9use generic_array::{
10    ArrayLength, GenericArray,
11    sequence::GenericSequence,
12    typenum::{U1, U2},
13};
14
15#[cfg(feature = "portable-simd")]
16#[allow(unused_imports)]
17use core::simd::{Swizzle as _, num::SimdUint, u32x4, u32x8, u32x16};
18
19#[allow(
20    unused_imports,
21    reason = "rust-analyzer doesn't consider -Ctarget-feature, silencing warnings"
22)]
23use crate::{
24    Align64,
25    simd::{Compose, ConcatLo, ExtractU32x2, FlipTable16, Inverse, Swizzle},
26};
27
/// In-place Salsa20 quarter-round over four words of a 16-word state.
///
/// `$w` must be a mutable place indexable by the four literal word indices. For the
/// indices `(a, b, c, d)` the update is (wrapping adds, left rotations):
///
/// ```text
/// b ^= (a + d) <<< 7
/// c ^= (b + a) <<< 9
/// d ^= (c + b) <<< 13
/// a ^= (d + c) <<< 18
/// ```
macro_rules! quarter_words {
    ($w:expr, $a:literal, $b:literal, $c:literal, $d:literal) => {{
        // Borrow the state once so `$w` is only named a single time.
        let state = &mut $w;
        state[$b] ^= state[$a].wrapping_add(state[$d]).rotate_left(7);
        state[$c] ^= state[$b].wrapping_add(state[$a]).rotate_left(9);
        state[$d] ^= state[$c].wrapping_add(state[$b]).rotate_left(13);
        state[$a] ^= state[$d].wrapping_add(state[$c]).rotate_left(18);
    }};
}
36
37/// Pivot to column-major order (A, B, D, C)
38struct Pivot;
39
40impl Swizzle<16> for Pivot {
41    const INDEX: [usize; 16] = [0, 5, 10, 15, 4, 9, 14, 3, 12, 1, 6, 11, 8, 13, 2, 7];
42}
43
44/// Round shuffle the first 4 lanes of a vector of u32
45#[allow(unused, reason = "rust-analyzer spam, actually used")]
46struct RoundShuffleAbdc;
47
48impl Swizzle<16> for RoundShuffleAbdc {
49    const INDEX: [usize; 16] = const {
50        let mut index = [0; 16];
51        let mut i = 0;
52        while i < 4 {
53            index[i] = i;
54            i += 1;
55        }
56        while i < 8 {
57            index[i] = 8 + (i + 1) % 4;
58            i += 1;
59        }
60        while i < 12 {
61            index[i] = 4 + (i + 3) % 4;
62            i += 1;
63        }
64        while i < 16 {
65            index[i] = 12 + (i + 2) % 4;
66            i += 1;
67        }
68        index
69    };
70}
71
/// Exposes the crate-level `Pivot` permutation to `core::simd`'s swizzle machinery,
/// reusing the exact same index table.
#[cfg(feature = "portable-simd")]
impl core::simd::Swizzle<16> for Pivot {
    const INDEX: [usize; 16] = <Pivot as Swizzle<16>>::INDEX;
}
76
/// Storage representation of one Salsa20 block for a particular core.
///
/// Every implementation in this module is a 64-byte block (16 `u32` words).
pub trait BlockType: Clone + Copy {
    /// Reads a block from `ptr`.
    ///
    /// # Safety
    ///
    /// `ptr` must be non-null, properly aligned for `Self`, and valid for reads of
    /// `size_of::<Self>()` bytes. These are the same requirements as [`core::ptr::read`].
    unsafe fn read_from_ptr(ptr: *const Self) -> Self;

    /// Writes `self` to `ptr`.
    ///
    /// # Safety
    ///
    /// `ptr` must be non-null, properly aligned for `Self`, and valid for writes of
    /// `size_of::<Self>()` bytes. These are the same requirements as [`core::ptr::write`].
    unsafe fn write_to_ptr(self, ptr: *mut Self);

    /// Bitwise-XORs `other` into `self`.
    fn xor_with(&mut self, other: Self);
}
86
/// AVX-512 representation: the whole 64-byte block in a single `__m512i`.
///
/// Only compiled when AVX-512F is enabled at build time, which is what makes the
/// intrinsic call in `xor_with` sound.
#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
impl BlockType for core::arch::x86_64::__m512i {
    #[inline(always)]
    unsafe fn read_from_ptr(ptr: *const Self) -> Self {
        // SAFETY: validity and alignment of `ptr` are the caller's obligation.
        unsafe { ptr.read() }
    }

    #[inline(always)]
    unsafe fn write_to_ptr(self, ptr: *mut Self) {
        // SAFETY: validity and alignment of `ptr` are the caller's obligation.
        unsafe { ptr.write(self) }
    }

    #[inline(always)]
    fn xor_with(&mut self, other: Self) {
        // SAFETY: this impl only exists when `avx512f` is a compile-time target feature.
        *self = unsafe { core::arch::x86_64::_mm512_xor_si512(*self, other) };
    }
}
105
106#[cfg(target_arch = "x86_64")]
107impl BlockType for [core::arch::x86_64::__m256i; 2] {
108    #[inline(always)]
109    unsafe fn read_from_ptr(ptr: *const Self) -> Self {
110        unsafe { core::ptr::read(ptr) }
111    }
112    #[inline(always)]
113    unsafe fn write_to_ptr(self, ptr: *mut Self) {
114        unsafe { core::ptr::write(ptr, self) };
115    }
116    #[inline(always)]
117    fn xor_with(&mut self, other: Self) {
118        use core::arch::x86_64::*;
119        unsafe {
120            self[0] = _mm256_xor_si256(self[0], other[0]);
121            self[1] = _mm256_xor_si256(self[1], other[1]);
122        }
123    }
124}
125
126impl BlockType for Align64<[u32; 16]> {
127    unsafe fn read_from_ptr(ptr: *const Self) -> Self {
128        unsafe { ptr.read() }
129    }
130    unsafe fn write_to_ptr(mut self, ptr: *mut Self) {
131        unsafe { ptr.write(self) }
132    }
133    fn xor_with(&mut self, other: Self) {
134        for i in 0..16 {
135            self.0[i] ^= other.0[i];
136        }
137    }
138}
139
/// Portable-SIMD representation: the whole block as one `u32x16` vector.
#[cfg(feature = "portable-simd")]
impl BlockType for core::simd::u32x16 {
    unsafe fn read_from_ptr(ptr: *const Self) -> Self {
        // SAFETY: validity and alignment of `ptr` are the caller's obligation.
        unsafe { core::ptr::read(ptr) }
    }

    unsafe fn write_to_ptr(self, ptr: *mut Self) {
        // SAFETY: validity and alignment of `ptr` are the caller's obligation.
        unsafe { core::ptr::write(ptr, self) }
    }

    fn xor_with(&mut self, other: Self) {
        *self = *self ^ other;
    }
}
152
153/// A trait for salsa20 block types
154pub trait Salsa20 {
155    /// The number of lanes
156    type Lanes: ArrayLength;
157    /// The block type
158    type Block: BlockType;
159
160    /// Shuffle data into optimal representation
161    fn shuffle_in(_ptr: &mut Align64<[u32; 16]>) {}
162
163    /// Shuffle data out of optimal representation
164    fn shuffle_out(_ptr: &mut Align64<[u32; 16]>) {}
165
166    /// Read block(s)
167    fn read(ptr: GenericArray<&Self::Block, Self::Lanes>) -> Self;
168    /// Write block(s) back
169    ///
170    /// The original/saved value must be present in the pointer.
171    fn write(&self, ptr: GenericArray<&mut Self::Block, Self::Lanes>);
172    /// Apply the keystream to the block(s)
173    fn keystream<const ROUND_PAIRS: usize>(&mut self);
174}
175
176/// A scalar solution
177#[allow(unused, reason = "Currently unused, but handy for testing")]
178pub struct BlockScalar<Lanes: ArrayLength> {
179    w: GenericArray<[u32; 16], Lanes>,
180}
181
182impl<Lanes: ArrayLength> Salsa20 for BlockScalar<Lanes> {
183    type Lanes = Lanes;
184    type Block = Align64<[u32; 16]>;
185
186    #[cfg(target_endian = "big")]
187    fn shuffle_in(ptr: &mut Align64<[u32; 16]>) {
188        for i in 0..16 {
189            ptr.0[i] = ptr.0[i].swap_bytes();
190        }
191    }
192
193    #[cfg(target_endian = "big")]
194    fn shuffle_out(ptr: &mut Align64<[u32; 16]>) {
195        for i in 0..16 {
196            ptr.0[i] = ptr.0[i].swap_bytes();
197        }
198    }
199
200    #[inline(always)]
201    fn read(ptr: GenericArray<&Self::Block, Lanes>) -> Self {
202        Self {
203            w: GenericArray::generate(|i| **ptr[i]),
204        }
205    }
206
207    #[inline(always)]
208    fn write(&self, mut ptr: GenericArray<&mut Self::Block, Lanes>) {
209        for i in 0..Lanes::USIZE {
210            for j in 0..16 {
211                ptr[i][j] = self.w[i][j];
212            }
213        }
214    }
215
216    #[inline(always)]
217    fn keystream<const ROUND_PAIRS: usize>(&mut self) {
218        let mut w = self.w.clone();
219
220        for i in 0..Lanes::USIZE {
221            for _ in 0..ROUND_PAIRS {
222                quarter_words!(w[i], 0, 4, 8, 12);
223                quarter_words!(w[i], 5, 9, 13, 1);
224                quarter_words!(w[i], 10, 14, 2, 6);
225                quarter_words!(w[i], 15, 3, 7, 11);
226
227                quarter_words!(w[i], 0, 1, 2, 3);
228                quarter_words!(w[i], 5, 6, 7, 4);
229                quarter_words!(w[i], 10, 11, 8, 9);
230                quarter_words!(w[i], 15, 12, 13, 14);
231            }
232        }
233
234        for i in 0..Lanes::USIZE {
235            for j in 0..16 {
236                self.w[i][j] = self.w[i][j].wrapping_add(w[i][j]);
237            }
238        }
239    }
240}
241
/// Single-lane Salsa20 core using portable SIMD.
///
/// The 64-byte state lives in four 128-bit `u32x4` vectors in the (A, B, D, C) layout
/// produced by `Pivot`:
/// - `a` holds words 0, 5, 10, 15;
/// - `b` holds words 4, 9, 14, 3;
/// - `d` holds words 12, 1, 6, 11;
/// - `c` holds words 8, 13, 2, 7.
#[cfg(feature = "portable-simd")]
pub struct BlockPortableSimd {
    a: u32x4,
    b: u32x4,
    c: u32x4,
    d: u32x4,
}
250
/// Rotates every lane of `x` left by `D` bits.
///
/// Built from a left shift, a right shift and an OR. `D` must lie in `1..32`, and this
/// is enforced at compile time. Amounts outside that range would either shift by the
/// full lane width or make `32 - D` underflow.
#[cfg(feature = "portable-simd")]
#[inline(always)]
fn simd_rotate_left<const N: usize, const D: u32>(
    x: core::simd::Simd<u32, N>,
) -> core::simd::Simd<u32, N>
where
    core::simd::LaneCount<N>: core::simd::SupportedLaneCount,
{
    const {
        assert!(D > 0 && D < 32, "rotation amount must be in 1..32");
    }
    (x << D) | (x >> (32 - D))
}
263
#[cfg(feature = "portable-simd")]
impl Salsa20 for BlockPortableSimd {
    type Lanes = U1;
    type Block = u32x16;

    /// Permutes a natural-order block into (A, B, D, C) layout via `Pivot`, byte-swapping
    /// each word on big-endian targets.
    #[inline(always)]
    fn shuffle_in(ptr: &mut Align64<[u32; 16]>) {
        let shuffled = Pivot::swizzle(u32x16::from_array(ptr.0));
        #[cfg(target_endian = "big")]
        let shuffled = shuffled.swap_bytes();
        ptr.0 = shuffled.to_array();
    }

    /// Exact inverse of `shuffle_in`.
    #[inline(always)]
    fn shuffle_out(ptr: &mut Align64<[u32; 16]>) {
        let restored = Inverse::<_, Pivot>::swizzle(u32x16::from_array(ptr.0));
        #[cfg(target_endian = "big")]
        let restored = restored.swap_bytes();
        ptr.0 = restored.to_array();
    }

    /// Splits the pivoted block into its four quarter vectors (stored as A, B, D, C).
    #[inline(always)]
    fn read(ptr: GenericArray<&Self::Block, U1>) -> Self {
        let block = ptr[0];
        Self {
            a: block.extract::<0, 4>(),
            b: block.extract::<4, 4>(),
            d: block.extract::<8, 4>(),
            c: block.extract::<12, 4>(),
        }
    }

    /// Reassembles A, B, D, C and adds them onto the untouched input block
    /// (feed-forward).
    #[inline(always)]
    fn write(&self, mut ptr: GenericArray<&mut Self::Block, U1>) {
        use crate::simd::Identity;

        let abdc = Identity::<16>::concat_swizzle(
            Identity::<8>::concat_swizzle(self.a, self.b),
            Identity::<8>::concat_swizzle(self.d, self.c),
        );
        *ptr[0] += abdc;
    }

    /// Runs `2 * ROUND_PAIRS` half-rounds. After each one, the lanes are re-aligned so
    /// the same four vector operations serve both column and row rounds.
    #[inline(always)]
    fn keystream<const ROUND_PAIRS: usize>(&mut self) {
        let (mut a, mut b, mut c, mut d) = (self.a, self.b, self.c, self.d);

        for _ in 0..2 * ROUND_PAIRS {
            b ^= simd_rotate_left::<_, 7>(a + d);
            c ^= simd_rotate_left::<_, 9>(b + a);
            d ^= simd_rotate_left::<_, 13>(c + b);
            a ^= simd_rotate_left::<_, 18>(d + c);

            // Re-align lanes: C rotates by two, while B and D trade places after
            // rotating by three and one respectively.
            c = c.rotate_elements_left::<2>();
            (b, d) = (d.rotate_elements_left::<1>(), b.rotate_elements_left::<3>());
        }

        (self.a, self.b, self.c, self.d) = (a, b, c, d);
    }
}
330
/// Two-lane Salsa20 core using portable SIMD.
///
/// Same (A, B, D, C) layout as the single-lane core, but each 256-bit `u32x8` packs the
/// matching quarter of two independent blocks. Lanes 0..4 belong to the first block and
/// lanes 4..8 to the second.
#[cfg(feature = "portable-simd")]
pub struct BlockPortableSimd2 {
    a: u32x8,
    b: u32x8,
    c: u32x8,
    d: u32x8,
}

#[cfg(feature = "portable-simd")]
impl Salsa20 for BlockPortableSimd2 {
    type Lanes = U2;
    type Block = u32x16;

    /// Per-block layout is identical to the single-lane portable core.
    #[inline(always)]
    fn shuffle_in(ptr: &mut Align64<[u32; 16]>) {
        <BlockPortableSimd as Salsa20>::shuffle_in(ptr)
    }

    /// Per-block layout is identical to the single-lane portable core.
    #[inline(always)]
    fn shuffle_out(ptr: &mut Align64<[u32; 16]>) {
        <BlockPortableSimd as Salsa20>::shuffle_out(ptr)
    }

    /// Gathers each quarter from both blocks.
    ///
    /// In the two-input swizzles, indices 0..16 address `first` and 16..32 address
    /// `second`.
    #[inline(always)]
    fn read(ptr: GenericArray<&Self::Block, U2>) -> Self {
        let (first, second) = (*ptr[0], *ptr[1]);
        Self {
            a: core::simd::simd_swizzle!(first, second, [0, 1, 2, 3, 16, 17, 18, 19]),
            b: core::simd::simd_swizzle!(first, second, [4, 5, 6, 7, 20, 21, 22, 23]),
            d: core::simd::simd_swizzle!(first, second, [8, 9, 10, 11, 24, 25, 26, 27]),
            c: core::simd::simd_swizzle!(first, second, [12, 13, 14, 15, 28, 29, 30, 31]),
        }
    }

    /// Scatters the quarters back into their blocks and adds them onto the untouched
    /// inputs (feed-forward).
    #[inline(always)]
    fn write(&self, mut ptr: GenericArray<&mut Self::Block, U2>) {
        // Interleave per block: `ab` = [a0, b0, a1, b1], `dc` = [d0, c0, d1, c1],
        // where the suffix is the block the 4-lane quarter belongs to.
        let ab = core::simd::simd_swizzle!(
            self.a,
            self.b,
            [0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15]
        );
        let dc = core::simd::simd_swizzle!(
            self.d,
            self.c,
            [0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15]
        );

        *ptr[0] += core::simd::simd_swizzle!(
            ab,
            dc,
            [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23]
        );
        *ptr[1] += core::simd::simd_swizzle!(
            ab,
            dc,
            [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31]
        );
    }

    /// Runs `2 * ROUND_PAIRS` half-rounds on both blocks at once.
    #[inline(always)]
    fn keystream<const ROUND_PAIRS: usize>(&mut self) {
        let (mut a, mut b, mut c, mut d) = (self.a, self.b, self.c, self.d);

        for _ in 0..2 * ROUND_PAIRS {
            b ^= simd_rotate_left::<_, 7>(a + d);
            c ^= simd_rotate_left::<_, 9>(b + a);
            d ^= simd_rotate_left::<_, 13>(c + b);
            a ^= simd_rotate_left::<_, 18>(d + c);

            // Rotate within each 4-lane half (one half per block). C rotates by two,
            // while B and D trade places after rotating by three and one respectively.
            c = core::simd::simd_swizzle!(c, [2, 3, 0, 1, 6, 7, 4, 5]);
            (b, d) = (
                core::simd::simd_swizzle!(d, [1, 2, 3, 0, 5, 6, 7, 4]),
                core::simd::simd_swizzle!(b, [3, 0, 1, 2, 7, 4, 5, 6]),
            );
        }

        (self.a, self.b, self.c, self.d) = (a, b, c, d);
    }
}
405
#[cfg(test)]
#[allow(unused_imports)]
mod tests {
    use generic_array::{GenericArray, typenum::U1};

    use super::*;

    /// Asserts that `S::shuffle_out` exactly undoes `S::shuffle_in` on a handful of
    /// pseudo-random blocks.
    pub(crate) fn test_shuffle_in_out_identity<S: Salsa20>()
    where
        S::Block: BlockType,
    {
        /// Marsaglia's xorshift32. From any non-zero seed it never reaches zero.
        fn xorshift32(x: &mut u32) -> u32 {
            *x ^= *x << 13;
            *x ^= *x >> 17;
            *x ^= *x << 5;
            *x
        }

        // The seed must be non-zero. Zero is a fixed point of xorshift-style generators,
        // so a zero seed would feed the identical block into every iteration.
        let mut state: u32 = 0x9E37_79B9;

        for _ in 0..5 {
            let test_input = Align64(core::array::from_fn(|i| {
                xorshift32(&mut state).wrapping_add(i as u32)
            }));

            let mut result = test_input.clone();
            S::shuffle_in(&mut result);
            S::shuffle_out(&mut result);
            assert_eq!(result, test_input);
        }
    }

    /// Cross-checks `BlockPortableSimd` against `BlockScalar` for `2 * ROUND_PAIRS` rounds.
    #[cfg(feature = "portable-simd")]
    fn test_keystream_portable_simd<const ROUND_PAIRS: usize>() {
        test_shuffle_in_out_identity::<BlockPortableSimd>();

        let test_input: Align64<[u32; 16]> = Align64(core::array::from_fn(|i| i as u32));
        let mut expected = test_input.clone();

        let mut test_input_scalar_shuffled = test_input.clone();
        BlockScalar::<U1>::shuffle_in(&mut test_input_scalar_shuffled);
        let mut block =
            BlockScalar::<U1>::read(GenericArray::from_array([&test_input_scalar_shuffled]));
        block.keystream::<ROUND_PAIRS>();
        block.write(GenericArray::from_array([&mut expected]));
        BlockScalar::<U1>::shuffle_out(&mut expected);

        let mut test_input_shuffled = test_input.clone();

        BlockPortableSimd::shuffle_in(&mut test_input_shuffled);
        let mut result = u32x16::from_array(*test_input_shuffled);

        let mut block_v = BlockPortableSimd::read(GenericArray::from_array([&result]));
        block_v.keystream::<ROUND_PAIRS>();
        block_v.write(GenericArray::from_array([&mut result]));

        let mut output = Align64(result.to_array());
        BlockPortableSimd::shuffle_out(&mut output);

        assert_eq!(output, expected);
    }

    /// Cross-checks `BlockPortableSimd2` against two independent `BlockScalar` runs for
    /// `2 * ROUND_PAIRS` rounds.
    #[cfg(feature = "portable-simd")]
    fn test_keystream_portable_simd2<const ROUND_PAIRS: usize>() {
        test_shuffle_in_out_identity::<BlockPortableSimd2>();

        let test_input0: Align64<[u32; 16]> = Align64(core::array::from_fn(|i| i as u32));
        let test_input1: Align64<[u32; 16]> = Align64(core::array::from_fn(|i| i as u32 + 16));
        let mut expected0 = test_input0.clone();
        let mut expected1 = test_input1.clone();

        let mut test_input0_scalar_shuffled = test_input0.clone();
        let mut test_input1_scalar_shuffled = test_input1.clone();
        BlockScalar::<U1>::shuffle_in(&mut test_input0_scalar_shuffled);
        BlockScalar::<U1>::shuffle_in(&mut test_input1_scalar_shuffled);

        let mut block0 =
            BlockScalar::<U1>::read(GenericArray::from_array([&test_input0_scalar_shuffled]));
        let mut block1 =
            BlockScalar::<U1>::read(GenericArray::from_array([&test_input1_scalar_shuffled]));
        block0.keystream::<ROUND_PAIRS>();
        block1.keystream::<ROUND_PAIRS>();
        block0.write(GenericArray::from_array([&mut expected0]));
        block1.write(GenericArray::from_array([&mut expected1]));
        BlockScalar::<U1>::shuffle_out(&mut expected0);
        BlockScalar::<U1>::shuffle_out(&mut expected1);

        let mut test_input0_shuffled = test_input0.clone();
        let mut test_input1_shuffled = test_input1.clone();
        BlockPortableSimd2::shuffle_in(&mut test_input0_shuffled);
        BlockPortableSimd2::shuffle_in(&mut test_input1_shuffled);

        let mut result0 = u32x16::from_array(*test_input0_shuffled);
        let mut result1 = u32x16::from_array(*test_input1_shuffled);

        let mut block_v0 = BlockPortableSimd2::read(GenericArray::from_array([&result0, &result1]));
        block_v0.keystream::<ROUND_PAIRS>();
        block_v0.write(GenericArray::from_array([&mut result0, &mut result1]));

        let mut output0 = Align64(result0.to_array());
        let mut output1 = Align64(result1.to_array());

        BlockPortableSimd2::shuffle_out(&mut output0);
        BlockPortableSimd2::shuffle_out(&mut output1);

        assert_eq!(output0, expected0);
        assert_eq!(output1, expected1);
    }

    // Test-name suffixes give the round count (2 * ROUND_PAIRS): 0, 2, 8 (Salsa20/8), 10.

    #[cfg(feature = "portable-simd")]
    #[test]
    fn test_keystream_portable_simd_0() {
        test_keystream_portable_simd::<0>();
    }

    #[cfg(feature = "portable-simd")]
    #[test]
    fn test_keystream_portable_simd_2() {
        test_keystream_portable_simd::<1>();
    }

    #[cfg(feature = "portable-simd")]
    #[test]
    fn test_keystream_portable_simd_8() {
        test_keystream_portable_simd::<4>();
    }

    #[cfg(feature = "portable-simd")]
    #[test]
    fn test_keystream_portable_simd_10() {
        test_keystream_portable_simd::<5>();
    }

    #[cfg(feature = "portable-simd")]
    #[test]
    fn test_keystream_portable_simd2_0() {
        test_keystream_portable_simd2::<0>();
    }

    #[cfg(feature = "portable-simd")]
    #[test]
    fn test_keystream_portable_simd2_2() {
        test_keystream_portable_simd2::<1>();
    }

    #[cfg(feature = "portable-simd")]
    #[test]
    fn test_keystream_portable_simd2_8() {
        test_keystream_portable_simd2::<4>();
    }

    #[cfg(feature = "portable-simd")]
    #[test]
    fn test_keystream_portable_simd2_10() {
        test_keystream_portable_simd2::<5>();
    }
}