1#![allow(
2 unused,
3 reason = "APIs that allow switching cores in code are not exposed to the public API, yet"
4)]
5
6#[cfg(target_arch = "x86_64")]
7pub(crate) mod x86_64;
8
9use generic_array::{
10 ArrayLength, GenericArray,
11 sequence::GenericSequence,
12 typenum::{U1, U2},
13};
14
15#[cfg(feature = "portable-simd")]
16#[allow(unused_imports)]
17use core::simd::{Swizzle as _, num::SimdUint, u32x4, u32x8, u32x16};
18
19#[allow(
20 unused_imports,
21 reason = "rust-analyzer doesn't consider -Ctarget-feature, silencing warnings"
22)]
23use crate::{
24 Align64,
25 simd::{Compose, ConcatLo, ExtractU32x2, FlipTable16, Inverse, Swizzle},
26};
27
/// Applies one Salsa20 quarter-round in place to the words at indices `$a`, `$b`, `$c`, `$d`
/// of the indexable `u32` state `$w`.
///
/// With `(a, b, c, d) = (y0, y1, y2, y3)` this is the specification's `quarterround`: each
/// step adds two words (mod 2^32), rotates the sum left and XORs it into the next word. The
/// steps are order-dependent: every step reads the word updated by the step before it.
macro_rules! quarter_words {
    ($w:expr, $a:literal, $b:literal, $c:literal, $d:literal) => {{
        let state = &mut $w;
        // b ^= (a + d) <<< 7
        state[$b] ^= state[$a].wrapping_add(state[$d]).rotate_left(7);
        // c ^= (b + a) <<< 9
        state[$c] ^= state[$b].wrapping_add(state[$a]).rotate_left(9);
        // d ^= (c + b) <<< 13
        state[$d] ^= state[$c].wrapping_add(state[$b]).rotate_left(13);
        // a ^= (d + c) <<< 18
        state[$a] ^= state[$d].wrapping_add(state[$c]).rotate_left(18);
    }};
}
36
/// Lane permutation from natural Salsa20 word order into the diagonal `a | b | d | c`
/// layout used by the SIMD cores: `a` holds words 0, 5, 10, 15; `b` holds 4, 9, 14, 3;
/// `d` holds 12, 1, 6, 11; `c` holds 8, 13, 2, 7.
struct Pivot;
39
40impl Swizzle<16> for Pivot {
41 const INDEX: [usize; 16] = [0, 5, 10, 15, 4, 9, 14, 3, 12, 1, 6, 11, 8, 13, 2, 7];
42}
43
/// Register permutation applied to an `a | b | d | c` block between Salsa20 rounds:
/// `a` stays put, `d` rotated left by one element moves into `b`'s slot, `b` rotated left
/// by three moves into `d`'s slot, and `c` is rotated left by two. This matches the lane
/// shuffle `BlockPortableSimd::keystream` performs on its separate registers.
#[allow(unused, reason = "rust-analyzer spam, actually used")]
struct RoundShuffleAbdc;
47
48impl Swizzle<16> for RoundShuffleAbdc {
49 const INDEX: [usize; 16] = const {
50 let mut index = [0; 16];
51 let mut i = 0;
52 while i < 4 {
53 index[i] = i;
54 i += 1;
55 }
56 while i < 8 {
57 index[i] = 8 + (i + 1) % 4;
58 i += 1;
59 }
60 while i < 12 {
61 index[i] = 4 + (i + 3) % 4;
62 i += 1;
63 }
64 while i < 16 {
65 index[i] = 12 + (i + 2) % 4;
66 i += 1;
67 }
68 index
69 };
70}
71
#[cfg(feature = "portable-simd")]
impl core::simd::Swizzle<16> for Pivot {
    // Reuse the crate-level table so both swizzle frameworks agree on the pivot layout.
    const INDEX: [usize; 16] = <Pivot as Swizzle<16>>::INDEX;
}
76
/// Storage for one 64-byte Salsa20 block in a backend-specific representation.
pub trait BlockType: Clone + Copy {
    /// Loads a block from `src`.
    ///
    /// # Safety
    ///
    /// The implementations in this file perform a plain `core::ptr::read`, so `src` must be
    /// non-null, properly aligned for `Self` and valid for reads of `Self`.
    unsafe fn read_from_ptr(src: *const Self) -> Self;

    /// Stores `self` to `dst`.
    ///
    /// # Safety
    ///
    /// The implementations in this file perform a plain `core::ptr::write`, so `dst` must be
    /// non-null, properly aligned for `Self` and valid for writes of `Self`.
    unsafe fn write_to_ptr(self, dst: *mut Self);

    /// XORs `rhs` into `self`.
    fn xor_with(&mut self, rhs: Self);
}
86
/// A whole Salsa20 block held in a single AVX-512 register.
#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
impl BlockType for core::arch::x86_64::__m512i {
    #[inline(always)]
    unsafe fn read_from_ptr(ptr: *const Self) -> Self {
        // SAFETY: validity and alignment of `ptr` are the caller's obligation.
        unsafe { ptr.read() }
    }

    #[inline(always)]
    unsafe fn write_to_ptr(self, ptr: *mut Self) {
        // SAFETY: validity and alignment of `ptr` are the caller's obligation.
        unsafe { ptr.write(self) }
    }

    #[inline(always)]
    fn xor_with(&mut self, other: Self) {
        // SAFETY: this impl is only compiled when `avx512f` is enabled for the target.
        *self = unsafe { core::arch::x86_64::_mm512_xor_si512(*self, other) };
    }
}
105
106#[cfg(target_arch = "x86_64")]
107impl BlockType for [core::arch::x86_64::__m256i; 2] {
108 #[inline(always)]
109 unsafe fn read_from_ptr(ptr: *const Self) -> Self {
110 unsafe { core::ptr::read(ptr) }
111 }
112 #[inline(always)]
113 unsafe fn write_to_ptr(self, ptr: *mut Self) {
114 unsafe { core::ptr::write(ptr, self) };
115 }
116 #[inline(always)]
117 fn xor_with(&mut self, other: Self) {
118 use core::arch::x86_64::*;
119 unsafe {
120 self[0] = _mm256_xor_si256(self[0], other[0]);
121 self[1] = _mm256_xor_si256(self[1], other[1]);
122 }
123 }
124}
125
126impl BlockType for Align64<[u32; 16]> {
127 unsafe fn read_from_ptr(ptr: *const Self) -> Self {
128 unsafe { ptr.read() }
129 }
130 unsafe fn write_to_ptr(mut self, ptr: *mut Self) {
131 unsafe { ptr.write(self) }
132 }
133 fn xor_with(&mut self, other: Self) {
134 for i in 0..16 {
135 self.0[i] ^= other.0[i];
136 }
137 }
138}
139
/// A Salsa20 block as a single 16-lane portable SIMD vector.
#[cfg(feature = "portable-simd")]
impl BlockType for core::simd::u32x16 {
    unsafe fn read_from_ptr(ptr: *const Self) -> Self {
        // SAFETY: validity and alignment of `ptr` are the caller's obligation.
        unsafe { core::ptr::read(ptr) }
    }

    unsafe fn write_to_ptr(self, ptr: *mut Self) {
        // SAFETY: validity and alignment of `ptr` are the caller's obligation.
        unsafe { core::ptr::write(ptr, self) }
    }

    fn xor_with(&mut self, other: Self) {
        *self = *self ^ other;
    }
}
152
153pub trait Salsa20 {
155 type Lanes: ArrayLength;
157 type Block: BlockType;
159
160 fn shuffle_in(_ptr: &mut Align64<[u32; 16]>) {}
162
163 fn shuffle_out(_ptr: &mut Align64<[u32; 16]>) {}
165
166 fn read(ptr: GenericArray<&Self::Block, Self::Lanes>) -> Self;
168 fn write(&self, ptr: GenericArray<&mut Self::Block, Self::Lanes>);
172 fn keystream<const ROUND_PAIRS: usize>(&mut self);
174}
175
/// Portable scalar Salsa20 core that processes `Lanes` blocks one word at a time.
#[allow(unused, reason = "Currently unused, but handy for testing")]
pub struct BlockScalar<Lanes: ArrayLength> {
    // Per-lane 16-word state in natural Salsa20 word order (no pivoting is applied).
    w: GenericArray<[u32; 16], Lanes>,
}
181
182impl<Lanes: ArrayLength> Salsa20 for BlockScalar<Lanes> {
183 type Lanes = Lanes;
184 type Block = Align64<[u32; 16]>;
185
186 #[cfg(target_endian = "big")]
187 fn shuffle_in(ptr: &mut Align64<[u32; 16]>) {
188 for i in 0..16 {
189 ptr.0[i] = ptr.0[i].swap_bytes();
190 }
191 }
192
193 #[cfg(target_endian = "big")]
194 fn shuffle_out(ptr: &mut Align64<[u32; 16]>) {
195 for i in 0..16 {
196 ptr.0[i] = ptr.0[i].swap_bytes();
197 }
198 }
199
200 #[inline(always)]
201 fn read(ptr: GenericArray<&Self::Block, Lanes>) -> Self {
202 Self {
203 w: GenericArray::generate(|i| **ptr[i]),
204 }
205 }
206
207 #[inline(always)]
208 fn write(&self, mut ptr: GenericArray<&mut Self::Block, Lanes>) {
209 for i in 0..Lanes::USIZE {
210 for j in 0..16 {
211 ptr[i][j] = self.w[i][j];
212 }
213 }
214 }
215
216 #[inline(always)]
217 fn keystream<const ROUND_PAIRS: usize>(&mut self) {
218 let mut w = self.w.clone();
219
220 for i in 0..Lanes::USIZE {
221 for _ in 0..ROUND_PAIRS {
222 quarter_words!(w[i], 0, 4, 8, 12);
223 quarter_words!(w[i], 5, 9, 13, 1);
224 quarter_words!(w[i], 10, 14, 2, 6);
225 quarter_words!(w[i], 15, 3, 7, 11);
226
227 quarter_words!(w[i], 0, 1, 2, 3);
228 quarter_words!(w[i], 5, 6, 7, 4);
229 quarter_words!(w[i], 10, 11, 8, 9);
230 quarter_words!(w[i], 15, 12, 13, 14);
231 }
232 }
233
234 for i in 0..Lanes::USIZE {
235 for j in 0..16 {
236 self.w[i][j] = self.w[i][j].wrapping_add(w[i][j]);
237 }
238 }
239 }
240}
241
/// Single-lane Salsa20 core built on `core::simd`. The 4x4 state is held as four diagonal
/// registers; `Pivot` documents which state words each register holds.
#[cfg(feature = "portable-simd")]
pub struct BlockPortableSimd {
    a: u32x4,
    b: u32x4,
    c: u32x4,
    d: u32x4,
}
250
/// Lane-wise `u32::rotate_left(D)` for a portable SIMD vector.
///
/// Bits shifted out at the top re-enter at the bottom. Intended for `0 < D < 32`; the
/// rotations used in this file are 7, 9, 13 and 18.
#[cfg(feature = "portable-simd")]
#[inline(always)]
fn simd_rotate_left<const N: usize, const D: u32>(
    x: core::simd::Simd<u32, N>,
) -> core::simd::Simd<u32, N>
where
    core::simd::LaneCount<N>: core::simd::SupportedLaneCount,
{
    (x << D) | (x >> (32 - D))
}
263
/// Single-block portable-SIMD core. `write` adds the mixed state onto the destination
/// block, so the feed-forward happens there rather than in `keystream`.
#[cfg(feature = "portable-simd")]
impl Salsa20 for BlockPortableSimd {
    type Lanes = U1;
    type Block = u32x16;

    /// Pivots the block into `a | b | d | c` order, byte-swapping words on big-endian.
    #[inline(always)]
    fn shuffle_in(ptr: &mut Align64<[u32; 16]>) {
        let natural = u32x16::from_array(ptr.0);
        let pivoted = Pivot::swizzle(natural);

        #[cfg(target_endian = "big")]
        let pivoted = pivoted.swap_bytes();

        ptr.0 = pivoted.to_array();
    }

    /// Undoes `shuffle_in`, restoring natural word order.
    #[inline(always)]
    fn shuffle_out(ptr: &mut Align64<[u32; 16]>) {
        let pivoted = u32x16::from_array(ptr.0);
        let natural = Inverse::<_, Pivot>::swizzle(pivoted);

        #[cfg(target_endian = "big")]
        let natural = natural.swap_bytes();

        ptr.0 = natural.to_array();
    }

    #[inline(always)]
    fn read(ptr: GenericArray<&Self::Block, U1>) -> Self {
        // The pivoted block is laid out as a | b | d | c, four words each.
        Self {
            a: ptr[0].extract::<0, 4>(),
            b: ptr[0].extract::<4, 4>(),
            d: ptr[0].extract::<8, 4>(),
            c: ptr[0].extract::<12, 4>(),
        }
    }

    #[inline(always)]
    fn write(&self, mut ptr: GenericArray<&mut Self::Block, U1>) {
        use crate::simd::Identity;

        let dc = Identity::<8>::concat_swizzle(self.d, self.c);
        let ab = Identity::<8>::concat_swizzle(self.a, self.b);
        // Feed-forward: the destination still holds the input block.
        *ptr[0] += Identity::<16>::concat_swizzle(ab, dc);
    }

    #[inline(always)]
    fn keystream<const ROUND_PAIRS: usize>(&mut self) {
        let Self { a, b, c, d } = self;

        // One iteration per round. The lane shuffle at the end turns column layout into row
        // layout and back again, since it is its own inverse.
        for _ in 0..2 * ROUND_PAIRS {
            *b ^= simd_rotate_left::<_, 7>(*a + *d);
            *c ^= simd_rotate_left::<_, 9>(*b + *a);
            *d ^= simd_rotate_left::<_, 13>(*c + *b);
            *a ^= simd_rotate_left::<_, 18>(*d + *c);

            let next_b = d.rotate_elements_left::<1>();
            let next_d = b.rotate_elements_left::<3>();
            *c = c.rotate_elements_left::<2>();
            *b = next_b;
            *d = next_d;
        }
    }
}
330
/// Two-lane variant of `BlockPortableSimd`. In each register, elements 0..4 belong to
/// lane 0 and elements 4..8 to lane 1.
#[cfg(feature = "portable-simd")]
pub struct BlockPortableSimd2 {
    a: u32x8,
    b: u32x8,
    c: u32x8,
    d: u32x8,
}
339
/// Two-block portable-SIMD core. It uses the same pivoted `a | b | d | c` block layout as
/// `BlockPortableSimd` and, like it, adds the mixed state onto the destination in `write`.
#[cfg(feature = "portable-simd")]
impl Salsa20 for BlockPortableSimd2 {
    type Lanes = U2;
    type Block = u32x16;

    #[inline(always)]
    fn shuffle_in(ptr: &mut Align64<[u32; 16]>) {
        // Same per-block layout as the single-lane core.
        BlockPortableSimd::shuffle_in(ptr);
    }

    #[inline(always)]
    fn shuffle_out(ptr: &mut Align64<[u32; 16]>) {
        BlockPortableSimd::shuffle_out(ptr);
    }

    #[inline(always)]
    fn read(ptr: GenericArray<&Self::Block, U2>) -> Self {
        let (lane0, lane1) = (*ptr[0], *ptr[1]);

        // Indices 0..16 select from `lane0`, 16..32 from `lane1`; each register receives the
        // matching quarter of both blocks.
        Self {
            a: core::simd::simd_swizzle!(lane0, lane1, [0, 1, 2, 3, 16, 17, 18, 19]),
            b: core::simd::simd_swizzle!(lane0, lane1, [4, 5, 6, 7, 20, 21, 22, 23]),
            d: core::simd::simd_swizzle!(lane0, lane1, [8, 9, 10, 11, 24, 25, 26, 27]),
            c: core::simd::simd_swizzle!(lane0, lane1, [12, 13, 14, 15, 28, 29, 30, 31]),
        }
    }

    #[inline(always)]
    fn write(&self, mut ptr: GenericArray<&mut Self::Block, U2>) {
        use crate::simd::Identity;

        // Low halves of the registers form block 0, high halves form block 1.
        let ab_lo = core::simd::simd_swizzle!(self.a, self.b, [0, 1, 2, 3, 8, 9, 10, 11]);
        let dc_lo = core::simd::simd_swizzle!(self.d, self.c, [0, 1, 2, 3, 8, 9, 10, 11]);
        let ab_hi = core::simd::simd_swizzle!(self.a, self.b, [4, 5, 6, 7, 12, 13, 14, 15]);
        let dc_hi = core::simd::simd_swizzle!(self.d, self.c, [4, 5, 6, 7, 12, 13, 14, 15]);

        // Feed-forward onto the destination blocks.
        *ptr[0] += Identity::<16>::concat_swizzle(ab_lo, dc_lo);
        *ptr[1] += Identity::<16>::concat_swizzle(ab_hi, dc_hi);
    }

    #[inline(always)]
    fn keystream<const ROUND_PAIRS: usize>(&mut self) {
        let Self { a, b, c, d } = self;

        // One iteration per round; the swizzles rotate elements within each 4-element half.
        for _ in 0..2 * ROUND_PAIRS {
            *b ^= simd_rotate_left::<_, 7>(*a + *d);
            *c ^= simd_rotate_left::<_, 9>(*b + *a);
            *d ^= simd_rotate_left::<_, 13>(*c + *b);
            *a ^= simd_rotate_left::<_, 18>(*d + *c);

            let next_b = core::simd::simd_swizzle!(*d, [1, 2, 3, 0, 5, 6, 7, 4]);
            let next_d = core::simd::simd_swizzle!(*b, [3, 0, 1, 2, 7, 4, 5, 6]);
            *c = core::simd::simd_swizzle!(*c, [2, 3, 0, 1, 6, 7, 4, 5]);
            *b = next_b;
            *d = next_d;
        }
    }
}
405
#[cfg(test)]
#[allow(unused_imports)]
mod tests {
    use generic_array::{GenericArray, typenum::U1};

    use super::*;

    /// Asserts that `S::shuffle_out` exactly undoes `S::shuffle_in` on several
    /// pseudo-random blocks.
    pub(crate) fn test_shuffle_in_out_identity<S: Salsa20>()
    where
        S::Block: BlockType,
    {
        /// Marsaglia xorshift32; zero is a fixed point, so it must be seeded non-zero.
        fn xorshift32(x: &mut u32) -> u32 {
            *x ^= *x << 13;
            *x ^= *x >> 17;
            *x ^= *x << 5;
            *x
        }

        // Non-zero seed so that every iteration checks a different block.
        let mut state: u32 = 0x9E37_79B9;

        for _ in 0..5 {
            let test_input: Align64<[u32; 16]> = Align64(core::array::from_fn(|i| {
                // Wrapping: the generator spans the full u32 range.
                xorshift32(&mut state).wrapping_add(i as u32)
            }));

            let mut result = test_input.clone();
            S::shuffle_in(&mut result);
            S::shuffle_out(&mut result);
            assert_eq!(result, test_input);
        }
    }

    #[test]
    fn test_shuffle_in_out_identity_scalar() {
        test_shuffle_in_out_identity::<BlockScalar<U1>>();
    }

    #[cfg(feature = "portable-simd")]
    fn test_keystream_portable_simd<const ROUND_PAIRS: usize>() {
        test_shuffle_in_out_identity::<BlockPortableSimd>();

        let test_input: Align64<[u32; 16]> = Align64(core::array::from_fn(|i| i as u32));
        let mut expected = test_input.clone();

        let mut test_input_scalar_shuffled = test_input.clone();
        BlockScalar::<U1>::shuffle_in(&mut test_input_scalar_shuffled);
        let mut block =
            BlockScalar::<U1>::read(GenericArray::from_array([&test_input_scalar_shuffled]));
        block.keystream::<ROUND_PAIRS>();
        block.write(GenericArray::from_array([&mut expected]));
        BlockScalar::<U1>::shuffle_out(&mut expected);

        let mut test_input_shuffled = test_input.clone();

        BlockPortableSimd::shuffle_in(&mut test_input_shuffled);
        let mut result = u32x16::from_array(*test_input_shuffled);

        let mut block_v = BlockPortableSimd::read(GenericArray::from_array([&result]));
        block_v.keystream::<ROUND_PAIRS>();
        block_v.write(GenericArray::from_array([&mut result]));

        let mut output = Align64(result.to_array());
        BlockPortableSimd::shuffle_out(&mut output);

        assert_eq!(output, expected);
    }

    #[cfg(feature = "portable-simd")]
    fn test_keystream_portable_simd2<const ROUND_PAIRS: usize>() {
        test_shuffle_in_out_identity::<BlockPortableSimd2>();

        let test_input0: Align64<[u32; 16]> = Align64(core::array::from_fn(|i| i as u32));
        let test_input1: Align64<[u32; 16]> = Align64(core::array::from_fn(|i| i as u32 + 16));
        let mut expected0 = test_input0.clone();
        let mut expected1 = test_input1.clone();

        let mut test_input0_scalar_shuffled = test_input0.clone();
        let mut test_input1_scalar_shuffled = test_input1.clone();
        BlockScalar::<U1>::shuffle_in(&mut test_input0_scalar_shuffled);
        BlockScalar::<U1>::shuffle_in(&mut test_input1_scalar_shuffled);

        let mut block0 =
            BlockScalar::<U1>::read(GenericArray::from_array([&test_input0_scalar_shuffled]));
        let mut block1 =
            BlockScalar::<U1>::read(GenericArray::from_array([&test_input1_scalar_shuffled]));
        block0.keystream::<ROUND_PAIRS>();
        block1.keystream::<ROUND_PAIRS>();
        block0.write(GenericArray::from_array([&mut expected0]));
        block1.write(GenericArray::from_array([&mut expected1]));
        BlockScalar::<U1>::shuffle_out(&mut expected0);
        BlockScalar::<U1>::shuffle_out(&mut expected1);

        let mut test_input0_shuffled = test_input0.clone();
        let mut test_input1_shuffled = test_input1.clone();
        BlockPortableSimd2::shuffle_in(&mut test_input0_shuffled);
        BlockPortableSimd2::shuffle_in(&mut test_input1_shuffled);

        let mut result0 = u32x16::from_array(*test_input0_shuffled);
        let mut result1 = u32x16::from_array(*test_input1_shuffled);

        let mut block_v0 = BlockPortableSimd2::read(GenericArray::from_array([&result0, &result1]));
        block_v0.keystream::<ROUND_PAIRS>();
        block_v0.write(GenericArray::from_array([&mut result0, &mut result1]));

        let mut output0 = Align64(result0.to_array());
        let mut output1 = Align64(result1.to_array());

        BlockPortableSimd2::shuffle_out(&mut output0);
        BlockPortableSimd2::shuffle_out(&mut output1);

        assert_eq!(output0, expected0);
        assert_eq!(output1, expected1);
    }

    #[cfg(feature = "portable-simd")]
    #[test]
    fn test_keystream_portable_simd_0() {
        test_keystream_portable_simd::<0>();
    }

    #[cfg(feature = "portable-simd")]
    #[test]
    fn test_keystream_portable_simd_2() {
        test_keystream_portable_simd::<1>();
    }

    #[cfg(feature = "portable-simd")]
    #[test]
    fn test_keystream_portable_simd_8() {
        test_keystream_portable_simd::<4>();
    }

    #[cfg(feature = "portable-simd")]
    #[test]
    fn test_keystream_portable_simd_10() {
        test_keystream_portable_simd::<5>();
    }

    #[cfg(feature = "portable-simd")]
    #[test]
    fn test_keystream_portable_simd2_0() {
        test_keystream_portable_simd2::<0>();
    }

    #[cfg(feature = "portable-simd")]
    #[test]
    fn test_keystream_portable_simd2_2() {
        test_keystream_portable_simd2::<1>();
    }

    #[cfg(feature = "portable-simd")]
    #[test]
    fn test_keystream_portable_simd2_8() {
        test_keystream_portable_simd2::<4>();
    }

    #[cfg(feature = "portable-simd")]
    #[test]
    fn test_keystream_portable_simd2_10() {
        test_keystream_portable_simd2::<5>();
    }
}