1#![allow(
2 unused,
3 reason = "APIs that allow switching cores in code are not exposed to the public API, yet"
4)]
5
// x86_64-specific code; this module is only compiled when targeting x86_64.
#[cfg(target_arch = "x86_64")]
pub(crate) mod x86_64;
8
9use generic_array::{
10 ArrayLength, GenericArray,
11 sequence::GenericSequence,
12 typenum::{IsLessOrEqual, U1, U2},
13};
14
15#[cfg(feature = "portable-simd")]
16#[allow(unused_imports)]
17use core::simd::{Swizzle as _, num::SimdUint, u32x4, u32x8, u32x16};
18
19#[allow(
20 unused_imports,
21 reason = "rust-analyzer doesn't consider -Ctarget-feature, silencing warnings"
22)]
23use crate::{
24 Align64,
25 simd::{Compose, ConcatLo, ExtractU32x2, FlipTable16, Inverse, Swizzle},
26};
27
/// Applies one Salsa20 quarter round in place to the words at indices
/// `$a`, `$b`, `$c`, `$d` of the indexable state `$w`.
///
/// Each step adds two already-updated words, rotates the sum left by a fixed
/// amount (7, 9, 13, 18) and XORs it into the next word.
macro_rules! quarter_words {
    ($state:expr, $a:literal, $b:literal, $c:literal, $d:literal) => {{
        // Add-rotate helper shared by the four steps.
        let mix = |x: u32, y: u32, r: u32| -> u32 { x.wrapping_add(y).rotate_left(r) };
        $state[$b] ^= mix($state[$a], $state[$d], 7);
        $state[$c] ^= mix($state[$b], $state[$a], 9);
        $state[$d] ^= mix($state[$c], $state[$b], 13);
        $state[$a] ^= mix($state[$d], $state[$c], 18);
    }};
}
36
37struct Pivot;
39
40impl Swizzle<16> for Pivot {
41 const INDEX: [usize; 16] = [0, 5, 10, 15, 4, 9, 14, 3, 12, 1, 6, 11, 8, 13, 2, 7];
42}
43
44#[allow(unused, reason = "rust-analyzer spam, actually used")]
46struct RoundShuffleAbdc;
47
48impl Swizzle<16> for RoundShuffleAbdc {
49 const INDEX: [usize; 16] = const {
50 let mut index = [0; 16];
51 let mut i = 0;
52 while i < 4 {
53 index[i] = i;
54 i += 1;
55 }
56 while i < 8 {
57 index[i] = 8 + (i + 1) % 4;
58 i += 1;
59 }
60 while i < 12 {
61 index[i] = 4 + (i + 3) % 4;
62 i += 1;
63 }
64 while i < 16 {
65 index[i] = 12 + (i + 2) % 4;
66 i += 1;
67 }
68 index
69 };
70}
71
#[cfg(feature = "portable-simd")]
impl core::simd::Swizzle<16> for Pivot {
    // Reuse the crate-level table so both swizzle traits agree on the permutation.
    const INDEX: [usize; 16] = <Pivot as Swizzle<16>>::INDEX;
}
76
/// In-memory representation of one 512-bit block (sixteen 32-bit words) that a
/// [`Salsa20`] core reads from and writes to.
pub trait BlockType: Clone + Copy {
    /// Reads a block from `ptr`.
    ///
    /// # Safety
    ///
    /// `ptr` must be non-null, properly aligned for `Self`, and valid for reads
    /// of `size_of::<Self>()` bytes that form an initialized value of `Self`.
    unsafe fn read_from_ptr(ptr: *const Self) -> Self;

    /// Writes `self` to `ptr`, overwriting the destination without dropping it.
    ///
    /// # Safety
    ///
    /// `ptr` must be non-null, properly aligned for `Self`, and valid for writes
    /// of `size_of::<Self>()` bytes.
    unsafe fn write_to_ptr(self, ptr: *mut Self);

    /// XORs `other` into `self` element-wise.
    fn xor_with(&mut self, other: Self);
}
86
/// One block as a single AVX-512 register (only with `avx512f` enabled at compile time).
#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
impl BlockType for core::arch::x86_64::__m512i {
    #[inline(always)]
    unsafe fn read_from_ptr(ptr: *const Self) -> Self {
        // SAFETY: validity and alignment of `ptr` are the caller's obligation.
        unsafe { ptr.read() }
    }

    #[inline(always)]
    unsafe fn write_to_ptr(self, ptr: *mut Self) {
        // SAFETY: validity and alignment of `ptr` are the caller's obligation.
        unsafe { ptr.write(self) }
    }

    #[inline(always)]
    fn xor_with(&mut self, other: Self) {
        // SAFETY: this impl only exists when `avx512f` is a compile-time target feature.
        *self = unsafe { core::arch::x86_64::_mm512_xor_si512(*self, other) };
    }
}
105
106#[cfg(target_arch = "x86_64")]
107impl BlockType for [core::arch::x86_64::__m256i; 2] {
108 #[inline(always)]
109 unsafe fn read_from_ptr(ptr: *const Self) -> Self {
110 unsafe { core::ptr::read(ptr) }
111 }
112 #[inline(always)]
113 unsafe fn write_to_ptr(self, ptr: *mut Self) {
114 unsafe { core::ptr::write(ptr, self) };
115 }
116 #[inline(always)]
117 fn xor_with(&mut self, other: Self) {
118 use core::arch::x86_64::*;
119 unsafe {
120 self[0] = _mm256_xor_si256(self[0], other[0]);
121 self[1] = _mm256_xor_si256(self[1], other[1]);
122 }
123 }
124}
125
126#[cfg(target_arch = "x86_64")]
127impl BlockType for [core::arch::x86_64::__m128i; 4] {
128 #[inline(always)]
129 unsafe fn read_from_ptr(ptr: *const Self) -> Self {
130 unsafe { core::ptr::read(ptr) }
131 }
132 #[inline(always)]
133 unsafe fn write_to_ptr(self, ptr: *mut Self) {
134 unsafe { core::ptr::write(ptr, self) };
135 }
136 #[inline(always)]
137 fn xor_with(&mut self, other: Self) {
138 use core::arch::x86_64::*;
139 unsafe {
140 self[0] = _mm_xor_si128(self[0], other[0]);
141 self[1] = _mm_xor_si128(self[1], other[1]);
142 self[2] = _mm_xor_si128(self[2], other[2]);
143 self[3] = _mm_xor_si128(self[3], other[3]);
144 }
145 }
146}
147
148impl BlockType for Align64<[u32; 16]> {
149 unsafe fn read_from_ptr(ptr: *const Self) -> Self {
150 unsafe { ptr.read() }
151 }
152 unsafe fn write_to_ptr(mut self, ptr: *mut Self) {
153 unsafe { ptr.write(self) }
154 }
155 fn xor_with(&mut self, other: Self) {
156 for i in 0..16 {
157 self.0[i] ^= other.0[i];
158 }
159 }
160}
161
/// One block as a 16-lane portable SIMD vector.
#[cfg(feature = "portable-simd")]
impl BlockType for core::simd::u32x16 {
    unsafe fn read_from_ptr(ptr: *const Self) -> Self {
        // SAFETY: validity and alignment of `ptr` are the caller's obligation.
        unsafe { core::ptr::read(ptr) }
    }

    unsafe fn write_to_ptr(self, ptr: *mut Self) {
        // SAFETY: validity and alignment of `ptr` are the caller's obligation.
        unsafe { core::ptr::write(ptr, self) }
    }

    fn xor_with(&mut self, other: Self) {
        *self = *self ^ other;
    }
}
174
175pub trait Salsa20 {
177 type Lanes: ArrayLength;
179 type Block: BlockType;
181
182 fn shuffle_in(_ptr: &mut Align64<[u32; 16]>) {}
184
185 fn shuffle_out(_ptr: &mut Align64<[u32; 16]>) {}
187
188 fn read(ptr: GenericArray<&Self::Block, Self::Lanes>) -> Self;
190 fn write(&self, ptr: GenericArray<&mut Self::Block, Self::Lanes>);
194 fn keystream<const ROUND_PAIRS: usize>(&mut self);
196}
197
/// Portable scalar implementation of [`Salsa20`] over `Lanes` independent blocks.
#[allow(unused, reason = "Currently unused, but handy for testing")]
pub struct BlockScalar<Lanes: ArrayLength> {
    // One 16-word state per lane.
    w: GenericArray<[u32; 16], Lanes>,
}
203
204impl<Lanes: ArrayLength> Salsa20 for BlockScalar<Lanes> {
205 type Lanes = Lanes;
206 type Block = Align64<[u32; 16]>;
207
208 #[cfg(target_endian = "big")]
209 fn shuffle_in(ptr: &mut Align64<[u32; 16]>) {
210 for i in 0..16 {
211 ptr.0[i] = ptr.0[i].swap_bytes();
212 }
213 }
214
215 #[cfg(target_endian = "big")]
216 fn shuffle_out(ptr: &mut Align64<[u32; 16]>) {
217 for i in 0..16 {
218 ptr.0[i] = ptr.0[i].swap_bytes();
219 }
220 }
221
222 #[inline(always)]
223 fn read(ptr: GenericArray<&Self::Block, Lanes>) -> Self {
224 Self {
225 w: GenericArray::generate(|i| **ptr[i]),
226 }
227 }
228
229 #[inline(always)]
230 fn write(&self, mut ptr: GenericArray<&mut Self::Block, Lanes>) {
231 for i in 0..Lanes::USIZE {
232 for j in 0..16 {
233 ptr[i][j] = self.w[i][j];
234 }
235 }
236 }
237
238 fn keystream<const ROUND_PAIRS: usize>(&mut self) {
239 let mut w = self.w.clone();
240
241 for _ in 0..ROUND_PAIRS {
242 for i in 0..Lanes::USIZE {
243 quarter_words!(w[i], 0, 4, 8, 12);
244 quarter_words!(w[i], 5, 9, 13, 1);
245 quarter_words!(w[i], 10, 14, 2, 6);
246 quarter_words!(w[i], 15, 3, 7, 11);
247
248 quarter_words!(w[i], 0, 1, 2, 3);
249 quarter_words!(w[i], 5, 6, 7, 4);
250 quarter_words!(w[i], 10, 11, 8, 9);
251 quarter_words!(w[i], 15, 12, 13, 14);
252 }
253 }
254
255 for i in 0..Lanes::USIZE {
256 for j in 0..16 {
257 self.w[i][j] = self.w[i][j].wrapping_add(w[i][j]);
258 }
259 }
260 }
261}
262
/// Portable-SIMD Salsa20 core for a single block, holding the shuffled state as
/// four 4-lane vectors.
#[cfg(feature = "portable-simd")]
pub struct BlockPortableSimd {
    // Elements 0..4 of the shuffled block (as loaded by `read`).
    a: u32x4,
    // Elements 4..8 of the shuffled block.
    b: u32x4,
    // Elements 12..16 of the shuffled block.
    c: u32x4,
    // Elements 8..12 of the shuffled block.
    d: u32x4,
}
271
/// Rotates every 32-bit lane of `x` left by `D` bits.
#[cfg(feature = "portable-simd")]
#[inline(always)]
fn simd_rotate_left<const N: usize, const D: u32>(
    x: core::simd::Simd<u32, N>,
) -> core::simd::Simd<u32, N>
where
    core::simd::LaneCount<N>: core::simd::SupportedLaneCount,
{
    (x << D) | (x >> (32 - D))
}
284
#[cfg(feature = "portable-simd")]
impl Salsa20 for BlockPortableSimd {
    type Lanes = U1;
    type Block = u32x16;

    /// Applies the `Pivot` permutation (and a per-word byte swap on big-endian targets).
    #[inline(always)]
    fn shuffle_in(ptr: &mut Align64<[u32; 16]>) {
        let permuted = Pivot::swizzle(u32x16::from_array(ptr.0));
        let words = if cfg!(target_endian = "big") {
            permuted.swap_bytes()
        } else {
            permuted
        };
        ptr.0 = words.to_array();
    }

    /// Undoes [`Salsa20::shuffle_in`] via the inverse `Pivot` permutation.
    #[inline(always)]
    fn shuffle_out(ptr: &mut Align64<[u32; 16]>) {
        let permuted = Inverse::<_, Pivot>::swizzle(u32x16::from_array(ptr.0));
        let words = if cfg!(target_endian = "big") {
            permuted.swap_bytes()
        } else {
            permuted
        };
        ptr.0 = words.to_array();
    }

    /// Splits the block into its four 4-lane groups (`a | b | d | c`).
    #[inline(always)]
    fn read(ptr: GenericArray<&Self::Block, U1>) -> Self {
        let block = ptr[0];
        Self {
            a: block.extract::<0, 4>(),
            b: block.extract::<4, 4>(),
            d: block.extract::<8, 4>(),
            c: block.extract::<12, 4>(),
        }
    }

    /// Reassembles `a | b | d | c` and adds it into the destination block.
    #[inline(always)]
    fn write(&self, mut ptr: GenericArray<&mut Self::Block, U1>) {
        use crate::simd::Identity;

        let low_half = Identity::<8>::concat_swizzle(self.a, self.b);
        let high_half = Identity::<8>::concat_swizzle(self.d, self.c);
        let output = Identity::<16>::concat_swizzle(low_half, high_half);
        *ptr[0] += output;
    }

    /// Runs `2 * ROUND_PAIRS` single rounds, re-aligning lanes after each one.
    #[inline(always)]
    fn keystream<const ROUND_PAIRS: usize>(&mut self) {
        let Self { a, b, c, d } = self;

        for _ in 0..2 * ROUND_PAIRS {
            *b ^= simd_rotate_left::<_, 7>(*a + *d);
            *c ^= simd_rotate_left::<_, 9>(*b + *a);
            *d ^= simd_rotate_left::<_, 13>(*c + *b);
            *a ^= simd_rotate_left::<_, 18>(*d + *c);

            // Re-align for the next round: `c` rotates by two lanes, while `b`
            // and `d` rotate (by three and one lanes) and swap roles.
            let next_b = d.rotate_elements_left::<1>();
            let next_d = b.rotate_elements_left::<3>();
            *c = c.rotate_elements_left::<2>();
            *b = next_b;
            *d = next_d;
        }
    }
}
351
/// Portable-SIMD Salsa20 core for two blocks at once. Each 8-lane vector holds
/// the same 4-lane group of block 0 (lanes 0..4) and block 1 (lanes 4..8).
#[cfg(feature = "portable-simd")]
pub struct BlockPortableSimd2 {
    // Elements 0..4 of each shuffled block.
    a: u32x8,
    // Elements 4..8 of each shuffled block.
    b: u32x8,
    // Elements 12..16 of each shuffled block.
    c: u32x8,
    // Elements 8..12 of each shuffled block.
    d: u32x8,
}
360
#[cfg(feature = "portable-simd")]
impl Salsa20 for BlockPortableSimd2 {
    type Lanes = U2;
    type Block = u32x16;

    /// Same layout as the single-block portable core.
    #[inline(always)]
    fn shuffle_in(ptr: &mut Align64<[u32; 16]>) {
        <BlockPortableSimd as Salsa20>::shuffle_in(ptr);
    }

    /// Same layout as the single-block portable core.
    #[inline(always)]
    fn shuffle_out(ptr: &mut Align64<[u32; 16]>) {
        <BlockPortableSimd as Salsa20>::shuffle_out(ptr);
    }

    /// Interleaves the matching 4-lane groups of both blocks into 8-lane vectors.
    #[inline(always)]
    fn read(ptr: GenericArray<&Self::Block, U2>) -> Self {
        let (first, second) = (*ptr[0], *ptr[1]);

        // Indices 0..16 select from `first`, 16..32 from `second`.
        Self {
            a: core::simd::simd_swizzle!(first, second, [0, 1, 2, 3, 16, 17, 18, 19]),
            b: core::simd::simd_swizzle!(first, second, [4, 5, 6, 7, 20, 21, 22, 23]),
            d: core::simd::simd_swizzle!(first, second, [8, 9, 10, 11, 24, 25, 26, 27]),
            c: core::simd::simd_swizzle!(first, second, [12, 13, 14, 15, 28, 29, 30, 31]),
        }
    }

    /// De-interleaves the state and adds each block's `a | b | d | c` into its destination.
    #[inline(always)]
    fn write(&self, mut ptr: GenericArray<&mut Self::Block, U2>) {
        use crate::simd::Identity;

        // Block 0 lives in lanes 0..4 of every vector (indices 8.. pick the second operand).
        let first_ab = core::simd::simd_swizzle!(self.a, self.b, [0, 1, 2, 3, 8, 9, 10, 11]);
        let first_dc = core::simd::simd_swizzle!(self.d, self.c, [0, 1, 2, 3, 8, 9, 10, 11]);
        *ptr[0] += Identity::<16>::concat_swizzle(first_ab, first_dc);

        // Block 1 lives in lanes 4..8 of every vector.
        let second_ab = core::simd::simd_swizzle!(self.a, self.b, [4, 5, 6, 7, 12, 13, 14, 15]);
        let second_dc = core::simd::simd_swizzle!(self.d, self.c, [4, 5, 6, 7, 12, 13, 14, 15]);
        *ptr[1] += Identity::<16>::concat_swizzle(second_ab, second_dc);
    }

    /// Runs `2 * ROUND_PAIRS` single rounds, rotating lanes within each 4-lane half.
    #[inline(always)]
    fn keystream<const ROUND_PAIRS: usize>(&mut self) {
        let Self { a, b, c, d } = self;

        for _ in 0..2 * ROUND_PAIRS {
            *b ^= simd_rotate_left::<_, 7>(*a + *d);
            *c ^= simd_rotate_left::<_, 9>(*b + *a);
            *d ^= simd_rotate_left::<_, 13>(*c + *b);
            *a ^= simd_rotate_left::<_, 18>(*d + *c);

            // Per-half rotations: d <<< 1 becomes b, b <<< 3 becomes d, c <<< 2.
            let next_b = core::simd::simd_swizzle!(*d, [1, 2, 3, 0, 5, 6, 7, 4]);
            let next_d = core::simd::simd_swizzle!(*b, [3, 0, 1, 2, 7, 4, 5, 6]);
            *c = core::simd::simd_swizzle!(*c, [2, 3, 0, 1, 6, 7, 4, 5]);
            *b = next_b;
            *d = next_d;
        }
    }
}
426
#[cfg(test)]
#[allow(unused_imports)]
mod tests {
    use generic_array::{GenericArray, typenum::U1};

    use super::*;

    /// Advances a xorshift32 generator and returns the new state.
    ///
    /// Deterministic pseudo-random source for test inputs. `state` must be
    /// non-zero: zero is a fixed point and would make every output zero.
    fn xorshift32(state: &mut u32) -> u32 {
        *state ^= *state << 13;
        *state ^= *state >> 17;
        *state ^= *state << 5;
        *state
    }

    /// Asserts that `S::shuffle_out` exactly undoes `S::shuffle_in` for several
    /// pseudo-random blocks.
    pub(crate) fn test_shuffle_in_out_identity<S: Salsa20>()
    where
        S::Block: BlockType,
    {
        // Non-zero seed: a zero seed is a fixed point of the generator and would
        // make every iteration test the same `0..16` ramp.
        let mut state: u32 = 0x9E37_79B9;

        for _ in 0..5 {
            let test_input = Align64(core::array::from_fn(|i| {
                // Generator outputs span the whole u32 range, so add with wrapping
                // to avoid overflow panics in debug builds.
                xorshift32(&mut state).wrapping_add(i as u32)
            }));

            let mut result = test_input.clone();
            S::shuffle_in(&mut result);
            S::shuffle_out(&mut result);
            assert_eq!(result, test_input);
        }
    }

    #[test]
    fn test_shuffle_in_out_identity_scalar() {
        test_shuffle_in_out_identity::<BlockScalar<U1>>();
    }

    /// Checks the single-block portable-SIMD core against the scalar core.
    #[cfg(feature = "portable-simd")]
    fn test_keystream_portable_simd<const ROUND_PAIRS: usize>() {
        test_shuffle_in_out_identity::<BlockPortableSimd>();

        let test_input: Align64<[u32; 16]> = Align64(core::array::from_fn(|i| i as u32));
        let mut expected = test_input.clone();

        let mut test_input_scalar_shuffled = test_input.clone();
        BlockScalar::<U1>::shuffle_in(&mut test_input_scalar_shuffled);
        let mut block =
            BlockScalar::<U1>::read(GenericArray::from_array([&test_input_scalar_shuffled]));
        block.keystream::<ROUND_PAIRS>();
        block.write(GenericArray::from_array([&mut expected]));
        BlockScalar::<U1>::shuffle_out(&mut expected);

        let mut test_input_shuffled = test_input.clone();

        BlockPortableSimd::shuffle_in(&mut test_input_shuffled);
        let mut result = u32x16::from_array(*test_input_shuffled);

        let mut block_v = BlockPortableSimd::read(GenericArray::from_array([&result]));
        block_v.keystream::<ROUND_PAIRS>();
        block_v.write(GenericArray::from_array([&mut result]));

        let mut output = Align64(result.to_array());
        BlockPortableSimd::shuffle_out(&mut output);

        assert_eq!(output, expected);
    }

    /// Checks the two-block portable-SIMD core against two runs of the scalar core.
    #[cfg(feature = "portable-simd")]
    fn test_keystream_portable_simd2<const ROUND_PAIRS: usize>() {
        test_shuffle_in_out_identity::<BlockPortableSimd2>();

        let test_input0: Align64<[u32; 16]> = Align64(core::array::from_fn(|i| i as u32));
        let test_input1: Align64<[u32; 16]> = Align64(core::array::from_fn(|i| i as u32 + 16));
        let mut expected0 = test_input0.clone();
        let mut expected1 = test_input1.clone();

        let mut test_input0_scalar_shuffled = test_input0.clone();
        let mut test_input1_scalar_shuffled = test_input1.clone();
        BlockScalar::<U1>::shuffle_in(&mut test_input0_scalar_shuffled);
        BlockScalar::<U1>::shuffle_in(&mut test_input1_scalar_shuffled);

        let mut block0 =
            BlockScalar::<U1>::read(GenericArray::from_array([&test_input0_scalar_shuffled]));
        let mut block1 =
            BlockScalar::<U1>::read(GenericArray::from_array([&test_input1_scalar_shuffled]));
        block0.keystream::<ROUND_PAIRS>();
        block1.keystream::<ROUND_PAIRS>();
        block0.write(GenericArray::from_array([&mut expected0]));
        block1.write(GenericArray::from_array([&mut expected1]));
        BlockScalar::<U1>::shuffle_out(&mut expected0);
        BlockScalar::<U1>::shuffle_out(&mut expected1);

        let mut test_input0_shuffled = test_input0.clone();
        let mut test_input1_shuffled = test_input1.clone();
        BlockPortableSimd2::shuffle_in(&mut test_input0_shuffled);
        BlockPortableSimd2::shuffle_in(&mut test_input1_shuffled);

        let mut result0 = u32x16::from_array(*test_input0_shuffled);
        let mut result1 = u32x16::from_array(*test_input1_shuffled);

        let mut block_v0 = BlockPortableSimd2::read(GenericArray::from_array([&result0, &result1]));
        block_v0.keystream::<ROUND_PAIRS>();
        block_v0.write(GenericArray::from_array([&mut result0, &mut result1]));

        let mut output0 = Align64(result0.to_array());
        let mut output1 = Align64(result1.to_array());

        BlockPortableSimd2::shuffle_out(&mut output0);
        BlockPortableSimd2::shuffle_out(&mut output1);

        assert_eq!(output0, expected0);
        assert_eq!(output1, expected1);
    }

    // Test-name suffixes give the number of single rounds (Salsa20/N); the const
    // argument is the number of double rounds, N / 2.

    #[cfg(feature = "portable-simd")]
    #[test]
    fn test_keystream_portable_simd_0() {
        test_keystream_portable_simd::<0>();
    }

    #[cfg(feature = "portable-simd")]
    #[test]
    fn test_keystream_portable_simd_2() {
        test_keystream_portable_simd::<1>();
    }

    #[cfg(feature = "portable-simd")]
    #[test]
    fn test_keystream_portable_simd_8() {
        test_keystream_portable_simd::<4>();
    }

    #[cfg(feature = "portable-simd")]
    #[test]
    fn test_keystream_portable_simd_10() {
        test_keystream_portable_simd::<5>();
    }

    #[cfg(feature = "portable-simd")]
    #[test]
    fn test_keystream_portable_simd2_0() {
        test_keystream_portable_simd2::<0>();
    }

    #[cfg(feature = "portable-simd")]
    #[test]
    fn test_keystream_portable_simd2_2() {
        test_keystream_portable_simd2::<1>();
    }

    #[cfg(feature = "portable-simd")]
    #[test]
    fn test_keystream_portable_simd2_8() {
        test_keystream_portable_simd2::<4>();
    }

    #[cfg(feature = "portable-simd")]
    #[test]
    fn test_keystream_portable_simd2_10() {
        test_keystream_portable_simd2::<5>();
    }
}