scrypt_opt/salsa20/
mod.rs

1#![allow(
2    unused,
3    reason = "APIs that allow switching cores in code are not exposed to the public API, yet"
4)]
5
6#[cfg(target_arch = "x86_64")]
7pub(crate) mod x86_64;
8
9use generic_array::{
10    ArrayLength, GenericArray,
11    sequence::GenericSequence,
12    typenum::{IsLessOrEqual, U1, U2},
13};
14
15#[cfg(feature = "portable-simd")]
16#[allow(unused_imports)]
17use core::simd::{Swizzle as _, num::SimdUint, u32x4, u32x8, u32x16};
18
19#[allow(
20    unused_imports,
21    reason = "rust-analyzer doesn't consider -Ctarget-feature, silencing warnings"
22)]
23use crate::{
24    Align64,
25    simd::{Compose, ConcatLo, ExtractU32x2, FlipTable16, Inverse, Swizzle},
26};
27
/// Applies one Salsa20 quarter-round, in place, to four words of `$state`.
///
/// `$a`, `$b`, `$c` and `$d` are literal indices into `$state`. The words are
/// updated in the order `b`, `c`, `d`, `a`; each one is XORed with the
/// wrapping sum of the two words touched just before it, rotated left by
/// 7, 9, 13 and 18 bits respectively.
///
/// `$state` is expanded once per access, so it should be a plain place
/// expression (for example `w` or `w[i]`) without side effects.
macro_rules! quarter_words {
    ($state:expr, $a:literal, $b:literal, $c:literal, $d:literal) => {
        $state[$b] ^= u32::rotate_left(u32::wrapping_add($state[$a], $state[$d]), 7);
        $state[$c] ^= u32::rotate_left(u32::wrapping_add($state[$b], $state[$a]), 9);
        $state[$d] ^= u32::rotate_left(u32::wrapping_add($state[$c], $state[$b]), 13);
        $state[$a] ^= u32::rotate_left(u32::wrapping_add($state[$d], $state[$c]), 18);
    };
}
36
37/// Pivot to column-major order (A, B, D, C)
38struct Pivot;
39
40impl Swizzle<16> for Pivot {
41    const INDEX: [usize; 16] = [0, 5, 10, 15, 4, 9, 14, 3, 12, 1, 6, 11, 8, 13, 2, 7];
42}
43
44/// Round shuffle the first 4 lanes of a vector of u32
45#[allow(unused, reason = "rust-analyzer spam, actually used")]
46struct RoundShuffleAbdc;
47
48impl Swizzle<16> for RoundShuffleAbdc {
49    const INDEX: [usize; 16] = const {
50        let mut index = [0; 16];
51        let mut i = 0;
52        while i < 4 {
53            index[i] = i;
54            i += 1;
55        }
56        while i < 8 {
57            index[i] = 8 + (i + 1) % 4;
58            i += 1;
59        }
60        while i < 12 {
61            index[i] = 4 + (i + 3) % 4;
62            i += 1;
63        }
64        while i < 16 {
65            index[i] = 12 + (i + 2) % 4;
66            i += 1;
67        }
68        index
69    };
70}
71
#[cfg(feature = "portable-simd")]
impl core::simd::Swizzle<16> for Pivot {
    // Reuse the table from the `crate::simd::Swizzle` impl so that both
    // swizzle traits always describe the same permutation.
    const INDEX: [usize; 16] = <Pivot as Swizzle<16>>::INDEX;
}
76
/// A trait for block types
///
/// A block type is a plain `Copy` value that can be moved through raw
/// pointers and XORed with another value of the same type.
pub trait BlockType: Clone + Copy {
    /// Read a block from a pointer
    ///
    /// # Safety
    ///
    /// The implementations in this file forward to a raw pointer read, so the
    /// caller must uphold the requirements of [`core::ptr::read`]: `ptr` must
    /// be valid for reads, properly aligned for `Self`, and point to an
    /// initialized value of `Self`.
    unsafe fn read_from_ptr(ptr: *const Self) -> Self;
    /// Write a block to a pointer
    ///
    /// # Safety
    ///
    /// The implementations in this file forward to a raw pointer write, so the
    /// caller must uphold the requirements of [`core::ptr::write`]: `ptr` must
    /// be valid for writes and properly aligned for `Self`.
    unsafe fn write_to_ptr(self, ptr: *mut Self);
    /// XOR a block with another block
    ///
    /// The result is stored in `self`; `other` is consumed by value.
    fn xor_with(&mut self, other: Self);
}
86
#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
impl BlockType for core::arch::x86_64::__m512i {
    /// Loads one 512-bit block from `ptr`.
    ///
    /// # Safety
    ///
    /// `ptr` must satisfy the requirements of a raw pointer read of `Self`.
    #[inline(always)]
    unsafe fn read_from_ptr(ptr: *const Self) -> Self {
        // SAFETY: the pointer requirements are delegated to the caller.
        unsafe { ptr.read() }
    }

    /// Stores this 512-bit block to `ptr`.
    ///
    /// # Safety
    ///
    /// `ptr` must satisfy the requirements of a raw pointer write of `Self`.
    #[inline(always)]
    unsafe fn write_to_ptr(self, ptr: *mut Self) {
        // SAFETY: the pointer requirements are delegated to the caller.
        unsafe { ptr.write(self) }
    }

    /// Replaces `self` with `self ^ other`.
    #[inline(always)]
    fn xor_with(&mut self, other: Self) {
        use core::arch::x86_64::_mm512_xor_si512;

        // SAFETY: this impl is only compiled when `avx512f` is statically
        // enabled (see the `cfg` on the impl).
        *self = unsafe { _mm512_xor_si512(*self, other) };
    }
}
105
106#[cfg(target_arch = "x86_64")]
107impl BlockType for [core::arch::x86_64::__m256i; 2] {
108    #[inline(always)]
109    unsafe fn read_from_ptr(ptr: *const Self) -> Self {
110        unsafe { core::ptr::read(ptr) }
111    }
112    #[inline(always)]
113    unsafe fn write_to_ptr(self, ptr: *mut Self) {
114        unsafe { core::ptr::write(ptr, self) };
115    }
116    #[inline(always)]
117    fn xor_with(&mut self, other: Self) {
118        use core::arch::x86_64::*;
119        unsafe {
120            self[0] = _mm256_xor_si256(self[0], other[0]);
121            self[1] = _mm256_xor_si256(self[1], other[1]);
122        }
123    }
124}
125
126#[cfg(target_arch = "x86_64")]
127impl BlockType for [core::arch::x86_64::__m128i; 4] {
128    #[inline(always)]
129    unsafe fn read_from_ptr(ptr: *const Self) -> Self {
130        unsafe { core::ptr::read(ptr) }
131    }
132    #[inline(always)]
133    unsafe fn write_to_ptr(self, ptr: *mut Self) {
134        unsafe { core::ptr::write(ptr, self) };
135    }
136    #[inline(always)]
137    fn xor_with(&mut self, other: Self) {
138        use core::arch::x86_64::*;
139        unsafe {
140            self[0] = _mm_xor_si128(self[0], other[0]);
141            self[1] = _mm_xor_si128(self[1], other[1]);
142            self[2] = _mm_xor_si128(self[2], other[2]);
143            self[3] = _mm_xor_si128(self[3], other[3]);
144        }
145    }
146}
147
148impl BlockType for Align64<[u32; 16]> {
149    unsafe fn read_from_ptr(ptr: *const Self) -> Self {
150        unsafe { ptr.read() }
151    }
152    unsafe fn write_to_ptr(mut self, ptr: *mut Self) {
153        unsafe { ptr.write(self) }
154    }
155    fn xor_with(&mut self, other: Self) {
156        for i in 0..16 {
157            self.0[i] ^= other.0[i];
158        }
159    }
160}
161
#[cfg(feature = "portable-simd")]
impl BlockType for core::simd::u32x16 {
    /// Loads one 16-lane vector from `ptr`.
    ///
    /// # Safety
    ///
    /// `ptr` must satisfy the requirements of [`core::ptr::read`].
    unsafe fn read_from_ptr(ptr: *const Self) -> Self {
        // SAFETY: the pointer requirements are delegated to the caller.
        unsafe { core::ptr::read(ptr) }
    }

    /// Stores this 16-lane vector to `ptr`.
    ///
    /// # Safety
    ///
    /// `ptr` must satisfy the requirements of [`core::ptr::write`].
    unsafe fn write_to_ptr(self, ptr: *mut Self) {
        // SAFETY: the pointer requirements are delegated to the caller.
        unsafe { core::ptr::write(ptr, self) }
    }

    /// Replaces `self` with the lane-wise `self ^ other`.
    fn xor_with(&mut self, other: Self) {
        *self = *self ^ other;
    }
}
174
/// A trait for salsa20 block types
///
/// An implementor holds the working state for `Lanes` blocks at once, each of
/// which lives in memory as a `Block`. The implementations in this file are
/// driven as: `shuffle_in` on the raw words, `read`, `keystream`, `write`,
/// then `shuffle_out`.
pub trait Salsa20 {
    /// The number of lanes
    ///
    /// This is the length of the reference arrays taken by `read` and `write`.
    type Lanes: ArrayLength;
    /// The block type
    type Block: BlockType;

    /// Shuffle data into optimal representation
    ///
    /// The default implementation leaves the words untouched.
    fn shuffle_in(_ptr: &mut Align64<[u32; 16]>) {}

    /// Shuffle data out of optimal representation
    ///
    /// The default implementation leaves the words untouched.
    fn shuffle_out(_ptr: &mut Align64<[u32; 16]>) {}

    /// Read block(s)
    fn read(ptr: GenericArray<&Self::Block, Self::Lanes>) -> Self;
    /// Write block(s) back
    ///
    /// The original/saved value must be present in the pointer.
    ///
    /// The portable-SIMD implementations in this file add their state onto
    /// the destination block, so the destination has to still hold the value
    /// that was passed to `read`; the scalar implementation overwrites the
    /// destination with an already-summed state.
    fn write(&self, ptr: GenericArray<&mut Self::Block, Self::Lanes>);
    /// Apply the keystream to the block(s)
    ///
    /// The implementations in this file run `2 * ROUND_PAIRS` Salsa20 rounds.
    fn keystream<const ROUND_PAIRS: usize>(&mut self);
}
197
198/// A scalar solution
199#[allow(unused, reason = "Currently unused, but handy for testing")]
200pub struct BlockScalar<Lanes: ArrayLength> {
201    w: GenericArray<[u32; 16], Lanes>,
202}
203
204impl<Lanes: ArrayLength> Salsa20 for BlockScalar<Lanes> {
205    type Lanes = Lanes;
206    type Block = Align64<[u32; 16]>;
207
208    #[cfg(target_endian = "big")]
209    fn shuffle_in(ptr: &mut Align64<[u32; 16]>) {
210        for i in 0..16 {
211            ptr.0[i] = ptr.0[i].swap_bytes();
212        }
213    }
214
215    #[cfg(target_endian = "big")]
216    fn shuffle_out(ptr: &mut Align64<[u32; 16]>) {
217        for i in 0..16 {
218            ptr.0[i] = ptr.0[i].swap_bytes();
219        }
220    }
221
222    #[inline(always)]
223    fn read(ptr: GenericArray<&Self::Block, Lanes>) -> Self {
224        Self {
225            w: GenericArray::generate(|i| **ptr[i]),
226        }
227    }
228
229    #[inline(always)]
230    fn write(&self, mut ptr: GenericArray<&mut Self::Block, Lanes>) {
231        for i in 0..Lanes::USIZE {
232            for j in 0..16 {
233                ptr[i][j] = self.w[i][j];
234            }
235        }
236    }
237
238    fn keystream<const ROUND_PAIRS: usize>(&mut self) {
239        let mut w = self.w.clone();
240
241        for _ in 0..ROUND_PAIRS {
242            for i in 0..Lanes::USIZE {
243                quarter_words!(w[i], 0, 4, 8, 12);
244                quarter_words!(w[i], 5, 9, 13, 1);
245                quarter_words!(w[i], 10, 14, 2, 6);
246                quarter_words!(w[i], 15, 3, 7, 11);
247
248                quarter_words!(w[i], 0, 1, 2, 3);
249                quarter_words!(w[i], 5, 6, 7, 4);
250                quarter_words!(w[i], 10, 11, 8, 9);
251                quarter_words!(w[i], 15, 12, 13, 14);
252            }
253        }
254
255        for i in 0..Lanes::USIZE {
256            for j in 0..16 {
257                self.w[i][j] = self.w[i][j].wrapping_add(w[i][j]);
258            }
259        }
260    }
261}
262
#[cfg(feature = "portable-simd")]
/// A solution for 1 lane of 128-bit blocks using portable SIMD
///
/// The 16-word block is held as four 4-lane vectors. The field comments give
/// the lanes of the shuffled block that `Salsa20::read` loads into each one.
pub struct BlockPortableSimd {
    /// Lanes 0..4 of the block.
    a: u32x4,
    /// Lanes 4..8 of the block.
    b: u32x4,
    /// Lanes 12..16 of the block.
    c: u32x4,
    /// Lanes 8..12 of the block.
    d: u32x4,
}
271
/// Rotates every 32-bit lane of `x` left by `D` bits.
///
/// The rotation is assembled from a left shift by `D` and a right shift by
/// `32 - D`. The callers in this file use `D` = 7, 9, 13 and 18; `32 - D`
/// underflows for `D > 32`.
#[cfg(feature = "portable-simd")]
#[inline(always)]
fn simd_rotate_left<const N: usize, const D: u32>(
    x: core::simd::Simd<u32, N>,
) -> core::simd::Simd<u32, N>
where
    core::simd::LaneCount<N>: core::simd::SupportedLaneCount,
{
    (x << D) | (x >> (32 - D))
}
284
#[cfg(feature = "portable-simd")]
impl Salsa20 for BlockPortableSimd {
    type Lanes = U1;
    type Block = u32x16;

    /// Applies the `Pivot` permutation to the 16 words, then byte-swaps them
    /// on big-endian targets.
    #[inline(always)]
    fn shuffle_in(ptr: &mut Align64<[u32; 16]>) {
        let shuffled = Pivot::swizzle(u32x16::from_array(ptr.0));

        #[cfg(target_endian = "big")]
        let shuffled = shuffled.swap_bytes();

        ptr.0 = shuffled.to_array();
    }

    /// Applies the inverse of the `Pivot` permutation to the 16 words, then
    /// byte-swaps them on big-endian targets.
    #[inline(always)]
    fn shuffle_out(ptr: &mut Align64<[u32; 16]>) {
        let restored = Inverse::<_, Pivot>::swizzle(u32x16::from_array(ptr.0));

        #[cfg(target_endian = "big")]
        let restored = restored.swap_bytes();

        ptr.0 = restored.to_array();
    }

    /// Splits the block into its four 4-lane groups, stored in A, B, D, C order.
    #[inline(always)]
    fn read(ptr: GenericArray<&Self::Block, U1>) -> Self {
        let block = ptr[0];

        Self {
            a: block.extract::<0, 4>(),
            b: block.extract::<4, 4>(),
            d: block.extract::<8, 4>(),
            c: block.extract::<12, 4>(),
        }
    }

    /// Reassembles the state in A, B, D, C order and adds it onto the
    /// destination block.
    #[inline(always)]
    fn write(&self, mut ptr: GenericArray<&mut Self::Block, U1>) {
        use crate::simd::Identity;

        // Concatenate the four groups back into one 16-lane vector.
        let lower = Identity::<8>::concat_swizzle(self.a, self.b);
        let upper = Identity::<8>::concat_swizzle(self.d, self.c);

        *ptr[0] += Identity::<16>::concat_swizzle(lower, upper);
    }

    /// Runs `2 * ROUND_PAIRS` rounds on the four groups; nothing happens when
    /// `ROUND_PAIRS` is zero.
    #[inline(always)]
    fn keystream<const ROUND_PAIRS: usize>(&mut self) {
        let Self { a, b, c, d } = self;

        for _ in 0..ROUND_PAIRS * 2 {
            // One quarter-round on all four groups at once.
            *b ^= simd_rotate_left::<_, 7>(*a + *d);
            *c ^= simd_rotate_left::<_, 9>(*b + *a);
            *d ^= simd_rotate_left::<_, 13>(*c + *b);
            *a ^= simd_rotate_left::<_, 18>(*d + *c);

            // Re-align the lanes for the next round; `b` and `d` trade places.
            let next_b = d.rotate_elements_left::<1>();
            let next_d = b.rotate_elements_left::<3>();
            *c = c.rotate_elements_left::<2>();
            *b = next_b;
            *d = next_d;
        }
    }
}
351
#[cfg(feature = "portable-simd")]
/// A solution for 2 lanes of 128-bit blocks using portable SIMD
///
/// Each field packs the same 4-lane group of two blocks: the low four lanes
/// come from the first block and the high four lanes from the second. The
/// field comments give the lanes of each shuffled block that `Salsa20::read`
/// loads.
pub struct BlockPortableSimd2 {
    /// Lanes 0..4 of both blocks.
    a: u32x8,
    /// Lanes 4..8 of both blocks.
    b: u32x8,
    /// Lanes 12..16 of both blocks.
    c: u32x8,
    /// Lanes 8..12 of both blocks.
    d: u32x8,
}
360
#[cfg(feature = "portable-simd")]
impl Salsa20 for BlockPortableSimd2 {
    type Lanes = U2;
    type Block = u32x16;

    /// Same shuffle as the single-lane portable implementation.
    #[inline(always)]
    fn shuffle_in(ptr: &mut Align64<[u32; 16]>) {
        <BlockPortableSimd as Salsa20>::shuffle_in(ptr);
    }

    /// Same shuffle as the single-lane portable implementation.
    #[inline(always)]
    fn shuffle_out(ptr: &mut Align64<[u32; 16]>) {
        <BlockPortableSimd as Salsa20>::shuffle_out(ptr);
    }

    /// Interleaves two blocks so that each state vector carries the same
    /// 4-lane group of both: first block in the low half, second in the high half.
    #[inline(always)]
    fn read(ptr: GenericArray<&Self::Block, U2>) -> Self {
        // Halves of each block: lanes 0..8 hold (A, B), lanes 8..16 hold (D, C).
        let first_ab = ptr[0].extract::<0, 8>();
        let first_dc = ptr[0].extract::<8, 8>();
        let second_ab = ptr[1].extract::<0, 8>();
        let second_dc = ptr[1].extract::<8, 8>();

        Self {
            a: core::simd::simd_swizzle!(first_ab, second_ab, [0, 1, 2, 3, 8, 9, 10, 11]),
            b: core::simd::simd_swizzle!(first_ab, second_ab, [4, 5, 6, 7, 12, 13, 14, 15]),
            d: core::simd::simd_swizzle!(first_dc, second_dc, [0, 1, 2, 3, 8, 9, 10, 11]),
            c: core::simd::simd_swizzle!(first_dc, second_dc, [4, 5, 6, 7, 12, 13, 14, 15]),
        }
    }

    /// De-interleaves the state back into two A, B, D, C ordered blocks and
    /// adds each onto its destination.
    #[inline(always)]
    fn write(&self, mut ptr: GenericArray<&mut Self::Block, U2>) {
        use crate::simd::Identity;

        // Low halves of the state vectors belong to the first block, high
        // halves to the second. LLVM composes these shuffles automatically.
        let first_ab = core::simd::simd_swizzle!(self.a, self.b, [0, 1, 2, 3, 8, 9, 10, 11]);
        let first_dc = core::simd::simd_swizzle!(self.d, self.c, [0, 1, 2, 3, 8, 9, 10, 11]);
        let second_ab = core::simd::simd_swizzle!(self.a, self.b, [4, 5, 6, 7, 12, 13, 14, 15]);
        let second_dc = core::simd::simd_swizzle!(self.d, self.c, [4, 5, 6, 7, 12, 13, 14, 15]);

        *ptr[0] += Identity::<16>::concat_swizzle(first_ab, first_dc);
        *ptr[1] += Identity::<16>::concat_swizzle(second_ab, second_dc);
    }

    /// Runs `2 * ROUND_PAIRS` rounds on both blocks at once; nothing happens
    /// when `ROUND_PAIRS` is zero.
    #[inline(always)]
    fn keystream<const ROUND_PAIRS: usize>(&mut self) {
        let Self { a, b, c, d } = self;

        for _ in 0..ROUND_PAIRS * 2 {
            // One quarter-round on all four groups of both blocks.
            *b ^= simd_rotate_left::<_, 7>(*a + *d);
            *c ^= simd_rotate_left::<_, 9>(*b + *a);
            *d ^= simd_rotate_left::<_, 13>(*c + *b);
            *a ^= simd_rotate_left::<_, 18>(*d + *c);

            // Rotate within each 4-lane half; `b` and `d` trade places.
            let next_b = core::simd::simd_swizzle!(*d, [1, 2, 3, 0, 5, 6, 7, 4]);
            let next_d = core::simd::simd_swizzle!(*b, [3, 0, 1, 2, 7, 4, 5, 6]);
            *c = core::simd::simd_swizzle!(*c, [2, 3, 0, 1, 6, 7, 4, 5]);
            *b = next_b;
            *d = next_d;
        }
    }
}
426
#[cfg(test)]
#[allow(unused_imports)]
mod tests {
    use generic_array::{GenericArray, typenum::U1};

    use super::*;

    /// Checks that `S::shuffle_out` undoes `S::shuffle_in` on a handful of
    /// pseudo-random 16-word inputs.
    pub(crate) fn test_shuffle_in_out_identity<S: Salsa20>()
    where
        S::Block: BlockType,
    {
        /// One step of the 32-bit xorshift generator (shifts 13, 17, 5).
        ///
        /// The state must be non-zero: zero is a fixed point of the
        /// generator, which would make every test input identical.
        fn xorshift32(x: &mut u32) -> u32 {
            *x ^= *x << 13;
            *x ^= *x >> 17;
            *x ^= *x << 5;
            *x
        }

        // Any non-zero seed works; a fixed one keeps the test deterministic.
        let mut state = 0x9E37_79B9_u32;

        for _ in 0..5 {
            // `wrapping_add` because the generator output can be close to `u32::MAX`.
            let test_input =
                Align64(core::array::from_fn(|i| xorshift32(&mut state).wrapping_add(i as u32)));

            let mut result = test_input.clone();
            S::shuffle_in(&mut result);
            S::shuffle_out(&mut result);
            assert_eq!(result, test_input);
        }
    }

    /// Compares `BlockPortableSimd` against `BlockScalar` on a single block.
    #[cfg(feature = "portable-simd")]
    fn test_keystream_portable_simd<const ROUND_PAIRS: usize>() {
        test_shuffle_in_out_identity::<BlockPortableSimd>();

        let test_input: Align64<[u32; 16]> = Align64(core::array::from_fn(|i| i as u32));
        let mut expected = test_input.clone();

        // Reference result from the scalar implementation.
        let mut test_input_scalar_shuffled = test_input.clone();
        BlockScalar::<U1>::shuffle_in(&mut test_input_scalar_shuffled);
        let mut block =
            BlockScalar::<U1>::read(GenericArray::from_array([&test_input_scalar_shuffled]));
        block.keystream::<ROUND_PAIRS>();
        block.write(GenericArray::from_array([&mut expected]));
        BlockScalar::<U1>::shuffle_out(&mut expected);

        // Same computation through the portable SIMD implementation.
        let mut test_input_shuffled = test_input.clone();

        BlockPortableSimd::shuffle_in(&mut test_input_shuffled);
        let mut result = u32x16::from_array(*test_input_shuffled);

        let mut block_v = BlockPortableSimd::read(GenericArray::from_array([&result]));
        block_v.keystream::<ROUND_PAIRS>();
        block_v.write(GenericArray::from_array([&mut result]));

        let mut output = Align64(result.to_array());
        BlockPortableSimd::shuffle_out(&mut output);

        assert_eq!(output, expected);
    }

    /// Compares `BlockPortableSimd2` against `BlockScalar` on two blocks.
    #[cfg(feature = "portable-simd")]
    fn test_keystream_portable_simd2<const ROUND_PAIRS: usize>() {
        test_shuffle_in_out_identity::<BlockPortableSimd2>();

        let test_input0: Align64<[u32; 16]> = Align64(core::array::from_fn(|i| i as u32));
        let test_input1: Align64<[u32; 16]> = Align64(core::array::from_fn(|i| i as u32 + 16));
        let mut expected0 = test_input0.clone();
        let mut expected1 = test_input1.clone();

        // Reference results from the scalar implementation, one block at a time.
        let mut test_input0_scalar_shuffled = test_input0.clone();
        let mut test_input1_scalar_shuffled = test_input1.clone();
        BlockScalar::<U1>::shuffle_in(&mut test_input0_scalar_shuffled);
        BlockScalar::<U1>::shuffle_in(&mut test_input1_scalar_shuffled);

        let mut block0 =
            BlockScalar::<U1>::read(GenericArray::from_array([&test_input0_scalar_shuffled]));
        let mut block1 =
            BlockScalar::<U1>::read(GenericArray::from_array([&test_input1_scalar_shuffled]));
        block0.keystream::<ROUND_PAIRS>();
        block1.keystream::<ROUND_PAIRS>();
        block0.write(GenericArray::from_array([&mut expected0]));
        block1.write(GenericArray::from_array([&mut expected1]));
        BlockScalar::<U1>::shuffle_out(&mut expected0);
        BlockScalar::<U1>::shuffle_out(&mut expected1);

        // Both blocks at once through the two-lane portable SIMD implementation.
        let mut test_input0_shuffled = test_input0.clone();
        let mut test_input1_shuffled = test_input1.clone();
        BlockPortableSimd2::shuffle_in(&mut test_input0_shuffled);
        BlockPortableSimd2::shuffle_in(&mut test_input1_shuffled);

        let mut result0 = u32x16::from_array(*test_input0_shuffled);
        let mut result1 = u32x16::from_array(*test_input1_shuffled);

        let mut block_v0 = BlockPortableSimd2::read(GenericArray::from_array([&result0, &result1]));
        block_v0.keystream::<ROUND_PAIRS>();
        block_v0.write(GenericArray::from_array([&mut result0, &mut result1]));

        let mut output0 = Align64(result0.to_array());
        let mut output1 = Align64(result1.to_array());

        BlockPortableSimd2::shuffle_out(&mut output0);
        BlockPortableSimd2::shuffle_out(&mut output1);

        assert_eq!(output0, expected0);
        assert_eq!(output1, expected1);
    }

    // The numeric suffix of each test name is the number of rounds, i.e.
    // twice the `ROUND_PAIRS` argument.

    #[cfg(feature = "portable-simd")]
    #[test]
    fn test_keystream_portable_simd_0() {
        test_keystream_portable_simd::<0>();
    }

    #[cfg(feature = "portable-simd")]
    #[test]
    fn test_keystream_portable_simd_2() {
        test_keystream_portable_simd::<1>();
    }

    #[cfg(feature = "portable-simd")]
    #[test]
    fn test_keystream_portable_simd_8() {
        test_keystream_portable_simd::<4>();
    }

    #[cfg(feature = "portable-simd")]
    #[test]
    fn test_keystream_portable_simd_10() {
        test_keystream_portable_simd::<5>();
    }

    #[cfg(feature = "portable-simd")]
    #[test]
    fn test_keystream_portable_simd2_0() {
        test_keystream_portable_simd2::<0>();
    }

    #[cfg(feature = "portable-simd")]
    #[test]
    fn test_keystream_portable_simd2_2() {
        test_keystream_portable_simd2::<1>();
    }

    #[cfg(feature = "portable-simd")]
    #[test]
    fn test_keystream_portable_simd2_8() {
        test_keystream_portable_simd2::<4>();
    }

    #[cfg(feature = "portable-simd")]
    #[test]
    fn test_keystream_portable_simd2_10() {
        test_keystream_portable_simd2::<5>();
    }
}