// scrypt_opt/fixed_r.rs

use core::num::NonZeroU8;

use generic_array::{
    ArrayLength, GenericArray,
    typenum::{B0, NonZero, U1, U2, UInt, Unsigned},
};

#[allow(unused_imports)]
use crate::features::Feature as _;
#[cfg(feature = "alloc")]
use crate::memory;
use crate::{
    DefaultEngine1, DefaultEngine2, MAX_N, ScryptBlockMixInput, ScryptBlockMixOutput,
    memory::Align64,
    pbkdf2_1::Pbkdf2HmacSha256State,
    pipeline::PipelineContext,
    salsa20::{BlockType, Salsa20},
};

/// The type for one block of the scrypt BlockMix operation (128 bytes per unit of R)
pub type Block<R> = GenericArray<u8, Mul128<R>>;

include!("block_mix.rs");

#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
const MAX_R_FOR_FULL_INTERLEAVED_ZMM: usize = 6; // 6 * 2 * 2 = 24 registers
const MAX_R_FOR_UNROLLING: usize = 8;

pub(crate) type Mul2<U> = UInt<U, B0>;
pub(crate) type Mul4<U> = UInt<Mul2<U>, B0>;
pub(crate) type Mul8<U> = UInt<Mul4<U>, B0>;
pub(crate) type Mul16<U> = UInt<Mul8<U>, B0>;
pub(crate) type Mul32<U> = UInt<Mul16<U>, B0>;
pub(crate) type Mul64<U> = UInt<Mul32<U>, B0>;
pub(crate) type Mul128<U> = UInt<Mul64<U>, B0>;
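
// Each `UInt<U, B0>` appends a zero bit to the type-level integer `U`, so every
// alias above doubles its argument; `Mul128<R>` is therefore the type-level
// value 128 * R, the byte length of one scrypt block.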

// Extract the scrypt Integerify value: the first 32-bit word of the last
// 64-byte sub-block of a block, read with native endianness (little-endian on
// the x86_64 targets this crate optimizes for), used to index into V.
macro_rules! integerify {
    (<$r:ty> $x:expr) => {{
        let input: &crate::fixed_r::Block<R> = $x;

        debug_assert_eq!(
            input.as_ptr().align_offset(64),
            0,
            "unexpected input alignment"
        );
        debug_assert_eq!(input.len(), Mul128::<R>::USIZE, "unexpected input length");
        #[allow(unused_unsafe)]
        unsafe {
            input
                .as_ptr()
                .cast::<u8>()
                .add(Mul128::<R>::USIZE - 64)
                .cast::<u32>()
                .read() as usize
        }
    }};
}
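
// A portable scalar model of `integerify!` (an illustrative sketch, not used by
// the implementation): the index is the first 32-bit word of the last 64-byte
// sub-block, read little-endian. The `integerify_scalar` helper is hypothetical.
#[cfg(test)]
mod integerify_sketch {
    fn integerify_scalar(block: &[u8]) -> usize {
        let tail = &block[block.len() - 64..];
        u32::from_le_bytes(tail[..4].try_into().unwrap()) as usize
    }

    #[test]
    fn reads_first_word_of_last_sub_block() {
        let mut block = [0u8; 128]; // R = 1, so a block is 128 bytes
        block[64] = 0x2a; // low byte of the word integerify reads
        assert_eq!(integerify_scalar(&block), 0x2a);
    }
}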

/// Convert a number of blocks to the effective Cost Factor, i.e. floor(log2(l - 2))
#[inline(always)]
pub const fn length_to_cf(l: usize) -> u8 {
    let v = (l.saturating_sub(2)) as u32;
    ((32 - v.leading_zeros()) as u8).saturating_sub(1)
}

/// Get the minimum number of blocks required for a given Cost Factor ((1 << cf) + 2)
#[inline(always)]
pub const fn minimum_blocks(cf: NonZeroU8) -> usize {
    let n = 1 << cf.get();
    n + 2
}
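
// A small sanity sketch (an illustrative test, not part of the original file)
// of the relationship between `length_to_cf` and `minimum_blocks`:
#[cfg(test)]
mod cf_roundtrip_sketch {
    use super::*;

    #[test]
    fn cf_roundtrip() {
        let cf = NonZeroU8::new(10).unwrap();
        // N = 1024 requires N + 2 = 1026 blocks...
        assert_eq!(minimum_blocks(cf), 1026);
        // ...and 1026 blocks yield an effective cf of 10.
        assert_eq!(length_to_cf(minimum_blocks(cf)), 10);
        // Extra blocks do not raise the effective cf until the next power of two.
        assert_eq!(length_to_cf(2049), 10);
        assert_eq!(length_to_cf(2050), 11);
    }
}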

/// A scrypt working buffer set with a compile-time fixed R value
#[derive(Debug, Clone, Copy)]
#[repr(transparent)]
pub struct BufferSet<
    Q: AsRef<[Align64<Block<R>>]> + AsMut<[Align64<Block<R>>]>,
    R: ArrayLength + NonZero,
> {
    v: Q,
    _r: core::marker::PhantomData<R>,
}

impl<Q: AsRef<[Align64<Block<R>>]> + AsMut<[Align64<Block<R>>]> + Default, R: ArrayLength + NonZero>
    Default for BufferSet<Q, R>
{
    #[inline(always)]
    fn default() -> Self {
        Self::new(Q::default())
    }
}

impl<Q: AsRef<[Align64<Block<R>>]> + AsMut<[Align64<Block<R>>]>, R: ArrayLength + NonZero> AsRef<Q>
    for BufferSet<Q, R>
{
    fn as_ref(&self) -> &Q {
        &self.v
    }
}

impl<Q: AsRef<[Align64<Block<R>>]> + AsMut<[Align64<Block<R>>]>, R: ArrayLength + NonZero> AsMut<Q>
    for BufferSet<Q, R>
{
    fn as_mut(&mut self) -> &mut Q {
        &mut self.v
    }
}

#[cfg(feature = "alloc")]
impl<R: ArrayLength + NonZero> BufferSet<alloc::vec::Vec<Align64<Block<R>>>, R> {
    /// Create a new buffer set in a box with a given Cost Factor (log2(N))
    #[inline(always)]
    pub fn new_boxed(cf: core::num::NonZeroU8) -> alloc::boxed::Box<Self> {
        let mut v = alloc::vec::Vec::new();
        v.resize(minimum_blocks(cf), Align64::<Block<R>>::default());
        alloc::boxed::Box::new(Self {
            v,
            _r: core::marker::PhantomData,
        })
    }
}
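
// Allocation sketch (an illustrative test, assuming the `alloc` feature):
// a boxed buffer set for N = 2^11 with R = 8 (1 KiB blocks).
#[cfg(all(test, feature = "alloc"))]
mod new_boxed_sketch {
    use super::*;

    #[test]
    fn boxed_allocation() {
        use generic_array::typenum::U8;
        let set = BufferSet::<alloc::vec::Vec<Align64<Block<U8>>>, U8>::new_boxed(
            core::num::NonZeroU8::new(11).unwrap(),
        );
        assert_eq!(set.n(), 1 << 11);
    }
}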

#[cfg(feature = "alloc")]
impl<R: ArrayLength + NonZero> BufferSet<memory::MaybeHugeSlice<Align64<Block<R>>>, R> {
    /// Create a new buffer set with a given Cost Factor (log2(N)), backed by huge pages when available
    #[inline(always)]
    pub fn new_maybe_huge_slice(
        cf: core::num::NonZeroU8,
    ) -> BufferSet<memory::MaybeHugeSlice<Align64<Block<R>>>, R> {
        use crate::memory;

        BufferSet {
            v: memory::MaybeHugeSlice::new_maybe(minimum_blocks(cf)),
            _r: core::marker::PhantomData,
        }
    }

    /// Create a new buffer set with a given Cost Factor (log2(N)), backed by huge pages when available,
    /// also returning the error, if any, from the huge page allocation attempt
    #[inline(always)]
    #[cfg(feature = "std")]
    pub fn new_maybe_huge_slice_ex(
        cf: core::num::NonZeroU8,
    ) -> (
        BufferSet<memory::MaybeHugeSlice<Align64<Block<R>>>, R>,
        Option<std::io::Error>,
    ) {
        let (v, e) = memory::MaybeHugeSlice::new(minimum_blocks(cf));
        (
            BufferSet {
                v,
                _r: core::marker::PhantomData,
            },
            e,
        )
    }
}

impl<Q: AsRef<[Align64<Block<R>>]> + AsMut<[Align64<Block<R>>]>, R: ArrayLength + NonZero>
    BufferSet<Q, R>
{
    /// Create a new buffer set
    ///
    /// # Panics
    ///
    /// Panics if the number of blocks is less than 4 or greater than MAX_N + 2.
    pub fn new(q: Q) -> Self {
        let l = q.as_ref().len();
        assert!(l >= 4, "number of blocks must be at least 4");
        assert!(
            l - 2 <= MAX_N as usize,
            "number of blocks must be at most MAX_N + 2"
        );
        Self {
            v: q,
            _r: core::marker::PhantomData,
        }
    }

    /// Create a new buffer set if the number of blocks is between 4 and MAX_N + 2
    ///
    /// # Returns
    ///
    /// None if the number of blocks is less than 4 or greater than MAX_N + 2
    pub fn try_new(q: Q) -> Option<Self> {
        let l = q.as_ref().len();
        if l < 4 {
            return None;
        }
        if l > MAX_N as usize + 2 {
            return None;
        }
        Some(Self {
            v: q,
            _r: core::marker::PhantomData,
        })
    }

    /// Consume the buffer set and return the inner buffer
    pub fn into_inner(self) -> Q {
        self.v
    }

    /// Get the input block buffer
    pub fn input_buffer(&self) -> &Align64<Block<R>> {
        &self.v.as_ref()[0]
    }

    /// Get the input block buffer mutably
    pub fn input_buffer_mut(&mut self) -> &mut Align64<Block<R>> {
        &mut self.v.as_mut()[0]
    }

    /// Set the input for the block buffer
    #[inline(always)]
    pub fn set_input(&mut self, hmac_state: &Pbkdf2HmacSha256State, salt: &[u8]) {
        hmac_state.emit_scatter(salt, [self.input_buffer_mut()]);
    }

    /// Get the effective Cost Factor (log2(N)) for the buffer set
    #[inline(always)]
    pub fn cf(&self) -> u8 {
        let l = self.v.as_ref().len();

        length_to_cf(l)
    }

    /// Get the effective N value for the buffer set
    #[inline(always)]
    pub fn n(&self) -> usize {
        let cf = self.cf();
        1 << cf
    }

    /// Get the raw salt output, useful for concatenation in the p > 1 case
    #[inline(always)]
    pub fn raw_salt_output(&self) -> &Align64<Block<R>> {
        // SAFETY: n() is at most len - 2 by construction, so index n is in bounds
        unsafe { self.v.as_ref().get_unchecked(self.n()) }
    }

    /// Extract the output from the block buffer
    #[inline(always)]
    pub fn extract_output(&self, hmac_state: &Pbkdf2HmacSha256State, output: &mut [u8]) {
        hmac_state.emit_gather([self.raw_salt_output()], output);
    }
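
    // Typical one-shot flow (a sketch; constructing `Pbkdf2HmacSha256State` is
    // left out since that API lives in `pbkdf2_1`):
    //
    //     buffers.set_input(&hmac_state, salt);              // PBKDF2 -> B
    //     buffers.scrypt_ro_mix();                           // B <- RoMix(B)
    //     buffers.extract_output(&hmac_state, &mut output);  // PBKDF2 -> DK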

    /// Shorten the buffer set into a smaller buffer set, returning the remainder as a slice.
    /// This is handy if you want to make one large allocation for the largest N you intend
    /// to use and reuse it for multiple Cost Factors.
    ///
    /// # Returns
    ///
    /// None if the number of blocks is less than the minimum number of blocks for the given Cost Factor.
    #[inline(always)]
    pub fn shorten(
        &mut self,
        cf: NonZeroU8,
    ) -> Option<(
        BufferSet<&mut [Align64<Block<R>>], R>,
        &mut [Align64<Block<R>>],
    )> {
        let min_blocks = minimum_blocks(cf);
        let (set, rest) = self.v.as_mut().split_at_mut_checked(min_blocks)?;
        Some((
            BufferSet {
                v: set,
                _r: core::marker::PhantomData,
            },
            rest,
        ))
    }
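
    // Reuse sketch: allocate once for the largest cost factor and borrow a
    // smaller set from the same storage (illustrative, assuming `alloc`):
    //
    //     let mut big = BufferSet::<Vec<Align64<Block<R>>>, R>::new_boxed(
    //         NonZeroU8::new(12).unwrap(),
    //     );
    //     let (mut small, _rest) = big.shorten(NonZeroU8::new(10).unwrap()).unwrap();
    //     assert_eq!(small.n(), 1 << 10);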

    /// Start an interleaved pipeline.
    #[cfg_attr(
        all(target_arch = "x86_64", not(target_feature = "avx2")),
        scrypt_opt_derive::generate_target_variant("avx2")
    )]
    pub(super) fn ro_mix_front_ex<S: Salsa20<Lanes = U1>>(&mut self) {
        let v = self.v.as_mut();
        let n = 1 << length_to_cf(v.len());

        // at least n+1 long; this is already enforced by length_to_cf, so we can disable it for release builds
        debug_assert!(v.len() > n, "ro_mix_front_ex: v.len() <= n");

        unsafe {
            v.get_unchecked_mut(0)
                .chunks_exact_mut(64)
                .for_each(|chunk| {
                    S::shuffle_in(
                        chunk
                            .as_mut_ptr()
                            .cast::<Align64<[u32; 16]>>()
                            .as_mut()
                            .unwrap(),
                    );
                });
        }

        // V[i+1] <- BlockMix(V[i]) for i in 0..n
        for i in 0..n {
            let [src, dst] = unsafe { v.get_disjoint_unchecked_mut([i, i + 1]) };
            block_mix!(<R>; [<S> &*src => &mut *dst]);
        }
    }

    /// Drain an interleaved pipeline.
    #[cfg_attr(
        all(target_arch = "x86_64", not(target_feature = "avx2")),
        scrypt_opt_derive::generate_target_variant("avx2")
    )]
    pub(super) fn ro_mix_back_ex<S: Salsa20<Lanes = U1>>(&mut self) {
        let v = self.v.as_mut();
        let n = 1 << length_to_cf(v.len());
        // at least n+2 long; this is already enforced by length_to_cf, so we can disable it for release builds
        debug_assert!(v.len() >= n + 2, "ro_mix_back_ex: v.len() < n + 2");

        for _ in (0..n).step_by(2) {
            let idx = integerify!(<R> unsafe { v.get_unchecked(n) });

            let j = idx & (n - 1);

            // Compute T <- BlockMix(B ^ V[j])
            // SAFETY: the largest j value is n-1, so the largest index of the 3 is n+1, which is in bounds after the >=n+2 check
            let [in0, in1, out] = unsafe { v.get_disjoint_unchecked_mut([n, j, n + 1]) };
            block_mix!(<R>; [<S> (&*in0, &*in1) => &mut *out]);
            let idx2 = integerify!(<R> unsafe { v.get_unchecked(n + 1) });

            let j2 = idx2 & (n - 1);

            // Compute B <- BlockMix(T ^ V[j'])
            // SAFETY: the largest j2 value is n-1, so the largest index of the 3 is n+1, which is in bounds after the >=n+2 check
            let [b, vj, t] = unsafe { v.get_disjoint_unchecked_mut([n, j2, n + 1]) };
            block_mix!(<R>; [<S> (&*vj, &*t) => &mut *b]);
        }

        unsafe {
            v.get_unchecked_mut(n)
                .chunks_exact_mut(64)
                .for_each(|chunk| {
                    S::shuffle_out(
                        chunk
                            .as_mut_ptr()
                            .cast::<Align64<[u32; 16]>>()
                            .as_mut()
                            .unwrap(),
                    );
                });
        }
    }

    /// Interleaved RoMix operation.
    ///
    /// $RoMix_{Back}$ is performed on self and $RoMix_{Front}$ is performed on other.
    ///
    /// # Panics
    ///
    /// Panics if the buffers are of different equivalent Cost Factors.
    #[cfg_attr(
        all(target_arch = "x86_64", not(target_feature = "avx2")),
        scrypt_opt_derive::generate_target_variant("avx2")
    )]
    pub(super) fn ro_mix_interleaved_ex<S: Salsa20<Lanes = U2>>(&mut self, other: &mut Self) {
        let self_v = self.v.as_mut();
        let other_v = other.v.as_mut();
        let self_cf = length_to_cf(self_v.len());
        let other_cf = length_to_cf(other_v.len());
        assert_eq!(
            self_cf, other_cf,
            "ro_mix_interleaved_ex: self_cf != other_cf, are you passing two buffers of the same size?"
        );
        let n = 1 << self_cf;

        // both buffers are at least n+2 long; this is already enforced by length_to_cf, so we
        // can disable the checks for release builds
        debug_assert!(
            other_v.len() >= n + 2,
            "ro_mix_interleaved_ex: other_v.len() < n + 2"
        );
        debug_assert!(
            self_v.len() >= n + 2,
            "ro_mix_interleaved_ex: self_v.len() < n + 2"
        );

        // SAFETY: other_v is always 64-byte aligned
        unsafe {
            other_v
                .get_unchecked_mut(0)
                .chunks_exact_mut(64)
                .for_each(|chunk| {
                    S::shuffle_in(
                        chunk
                            .as_mut_ptr()
                            .cast::<Align64<[u32; 16]>>()
                            .as_mut()
                            .unwrap_unchecked(),
                    );
                });
        }

        for i in (0..n).step_by(2) {
            // SAFETY: the largest i value is n-2, so the largest index is n, which is in bounds after the >=n+2 check
            let [src, middle, dst] =
                unsafe { other_v.get_disjoint_unchecked_mut([i, i + 1, i + 2]) };

            {
                // Self: Compute T <- BlockMix(B ^ V[j])
                // Other: Compute V[i+1] <- BlockMix(V[i])
                let idx = integerify!(<R> unsafe { self_v.get_unchecked(n) });

                let j = idx & (n - 1);

                let [in0, in1, out] = unsafe { self_v.get_disjoint_unchecked_mut([j, n, n + 1]) };

                block_mix!(<R>; [<S> &*src => &mut *middle, <S> (&*in0, &*in1) => &mut *out]);
            }

            {
                // Self: Compute B <- BlockMix(T ^ V[j'])
                // Other: Compute V[i+2] <- BlockMix(V[i+1]); on the last iteration this "naturally overflows" into V[n], so let B = V[n]
                let idx2 = integerify!(<R> unsafe { self_v.get_unchecked(n + 1) });

                let j2 = idx2 & (n - 1);
                let [self_b, self_vj, self_t] =
                    unsafe { self_v.get_disjoint_unchecked_mut([n, j2, n + 1]) };

                block_mix!(<R>; [<S> &*middle => &mut *dst, <S> (&*self_vj, &*self_t) => &mut *self_b]);
            }
        }

        // SAFETY: self_v is always 64-byte aligned
        unsafe {
            self_v
                .get_unchecked_mut(n)
                .chunks_exact_mut(64)
                .for_each(|chunk| {
                    S::shuffle_out(
                        chunk
                            .as_mut_ptr()
                            .cast::<Align64<[u32; 16]>>()
                            .as_mut()
                            .unwrap_unchecked(),
                    );
                });
        }
    }

    /// Start an interleaved pipeline using the default engine by performing the $RoMix_{Front}$ operation.
    pub fn ro_mix_front(&mut self) {
        #[cfg(all(not(test), target_arch = "x86_64", not(target_feature = "avx2")))]
        {
            if crate::features::Avx2.check() {
                unsafe {
                    self.ro_mix_front_ex_avx2::<crate::salsa20::x86_64::BlockSse2<U1>>();
                }
                return;
            }
        }

        self.ro_mix_front_ex::<DefaultEngine1>();
    }

    /// Drain an interleaved pipeline using the default engine by performing the $RoMix_{Back}$ operation.
    pub fn ro_mix_back(&mut self) {
        #[cfg(all(not(test), target_arch = "x86_64", not(target_feature = "avx2")))]
        {
            if crate::features::Avx2.check() {
                unsafe {
                    self.ro_mix_back_ex_avx2::<crate::salsa20::x86_64::BlockSse2<U1>>();
                }
                return;
            }
        }

        self.ro_mix_back_ex::<DefaultEngine1>();
    }

    /// Perform the RoMix operation using the default engine.
    ///
    /// Equivalent to `ro_mix_front` followed by `ro_mix_back` on the same buffer set.
    pub fn scrypt_ro_mix(&mut self) {
        // If possible, redirect to the register-resident implementation to avoid data access thrashing.
        #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
        if R::USIZE <= MAX_R_FOR_UNROLLING {
            self.scrypt_ro_mix_ex_zmm::<crate::salsa20::x86_64::BlockAvx512F>();
            return;
        }

        #[cfg(all(not(test), target_arch = "x86_64", not(target_feature = "avx2")))]
        {
            if crate::features::Avx2.check() {
                unsafe {
                    self.ro_mix_front_ex_avx2::<crate::salsa20::x86_64::BlockSse2<U1>>();
                    self.ro_mix_back_ex_avx2::<crate::salsa20::x86_64::BlockSse2<U1>>();
                }
                return;
            }
        }

        self.ro_mix_front_ex::<DefaultEngine1>();
        self.ro_mix_back_ex::<DefaultEngine1>();
    }

    /// Perform the RoMix operation with interleaved buffers.
    ///
    /// $RoMix_{Back}$ is performed on self and $RoMix_{Front}$ is performed on other.
    ///
    /// # Panics
    ///
    /// Panics if the buffers are of different equivalent Cost Factors.
    pub fn ro_mix_interleaved(&mut self, other: &mut Self) {
        // If possible, steer to the register-resident AVX-512 implementation to avoid cache line thrashing.
        #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
        if R::USIZE <= MAX_R_FOR_UNROLLING {
            self.ro_mix_interleaved_ex_zmm::<crate::salsa20::x86_64::BlockAvx512FMb2>(other);
            return;
        }

        #[cfg(all(not(test), target_arch = "x86_64", not(target_feature = "avx2")))]
        {
            if crate::features::Avx2.check() {
                unsafe {
                    self.ro_mix_interleaved_ex_avx2::<crate::salsa20::x86_64::BlockAvx2Mb2>(other);
                }
                return;
            }
        }

        self.ro_mix_interleaved_ex::<DefaultEngine2>(other);
    }
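
    // Intended two-buffer call pattern (a sketch; `a` and `b` are buffer sets
    // of equal cost factor):
    //
    //     a.ro_mix_front();                 // front half of job 0
    //     a.ro_mix_interleaved(&mut b);     // back(job 0) on a, front(job 1) on b
    //     // ... extract job 0's output from `a`, swap `a` and `b`, repeat ...
    //     a.ro_mix_back();                  // drain the last job
    //
    // `pipeline` below packages this schedule over an iterator of jobs.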

    /// Pipeline RoMix operations on an iterator of inputs.
    ///
    /// `self` and `other` serve as the two in-flight buffer sets: each job's
    /// memory-bound $RoMix_{Back}$ half is interleaved with the next job's
    /// $RoMix_{Front}$ half, and the buffers swap roles every iteration.
    pub fn pipeline<K, S, C: PipelineContext<S, Q, R, K>, I: IntoIterator<Item = C>>(
        &mut self,
        other: &mut Self,
        iter: I,
        state: &mut S,
    ) -> Option<K> {
        let mut iter = iter.into_iter();

        let (mut buffers0, mut buffers1) = (&mut *self, &mut *other);
        // input_m2/input_m1 track the jobs two and one stages behind the one being enqueued
        let mut input_m2 = iter.next()?;
        input_m2.begin(state, buffers0);
        let Some(mut input_m1) = iter.next() else {
            buffers0.scrypt_ro_mix();
            return input_m2.drain(state, buffers0);
        };
        input_m1.begin(state, buffers1);

        #[cfg(all(not(test), target_arch = "x86_64", not(target_feature = "avx2")))]
        {
            if crate::features::Avx2.check() {
                unsafe {
                    buffers0.ro_mix_front_ex_avx2::<crate::salsa20::x86_64::BlockSse2<U1>>();
                    loop {
                        buffers0
                            .ro_mix_interleaved_ex_avx2::<crate::salsa20::x86_64::BlockAvx2Mb2>(
                                buffers1,
                            );
                        if let Some(k) = input_m2.drain(state, buffers0) {
                            return Some(k);
                        }

                        (buffers0, buffers1) = (buffers1, buffers0);

                        let Some(mut input) = iter.next() else {
                            break;
                        };

                        input.begin(state, buffers1);

                        input_m2 = input_m1;
                        input_m1 = input;
                    }
                    buffers0.ro_mix_back_ex_avx2::<crate::salsa20::x86_64::BlockSse2<U1>>();
                    return input_m1.drain(state, buffers0);
                }
            }
        }

        buffers0.ro_mix_front();
        loop {
            buffers0.ro_mix_interleaved(buffers1);
            if let Some(k) = input_m2.drain(state, buffers0) {
                return Some(k);
            }

            (buffers0, buffers1) = (buffers1, buffers0);

            let Some(mut input) = iter.next() else {
                break;
            };

            input.begin(state, buffers1);

            input_m2 = input_m1;
            input_m1 = input;
        }
        buffers0.ro_mix_back();
        input_m1.drain(state, buffers0)
    }
}

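// Equivalence sketch (an illustrative test, assuming the `alloc` feature and
// that `Align64<Block<R>>` derefs to a byte slice as used elsewhere in this
// file): the split front/back pipeline must match the one-shot RoMix.
#[cfg(all(test, feature = "alloc"))]
mod split_pipeline_sketch {
    use super::*;
    use generic_array::typenum::U1;

    #[test]
    fn front_plus_back_matches_one_shot() {
        let cf = NonZeroU8::new(8).unwrap();
        let mut a = BufferSet::<alloc::vec::Vec<Align64<Block<U1>>>, U1>::new_boxed(cf);
        let mut b = BufferSet::<alloc::vec::Vec<Align64<Block<U1>>>, U1>::new_boxed(cf);
        for (i, byte) in a.input_buffer_mut().iter_mut().enumerate() {
            *byte = i as u8;
        }
        *b.input_buffer_mut() = a.input_buffer().clone();
        a.scrypt_ro_mix();
        b.ro_mix_front();
        b.ro_mix_back();
        assert_eq!(&a.raw_salt_output()[..], &b.raw_salt_output()[..]);
    }
}
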
#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
impl<Q: AsRef<[Align64<Block<R>>]> + AsMut<[Align64<Block<R>>]>, R: ArrayLength + NonZero>
    BufferSet<Q, R>
{
    /// Perform the RoMix operation using AVX-512 registers as temporary storage.
    pub(super) fn scrypt_ro_mix_ex_zmm<
        S: Salsa20<Lanes = U1, Block = core::arch::x86_64::__m512i>,
    >(
        &mut self,
    ) {
        assert!(
            R::USIZE <= MAX_R_FOR_UNROLLING,
            "scrypt_ro_mix_ex_zmm: R > {}",
            MAX_R_FOR_UNROLLING
        );
        let v = self.v.as_mut();
        let n = 1 << length_to_cf(v.len());
        // at least n+1 long; this is enforced by length_to_cf
        debug_assert!(v.len() > n, "scrypt_ro_mix_ex_zmm: v.len() <= n");

        unsafe {
            v.get_unchecked_mut(0)
                .chunks_exact_mut(64)
                .for_each(|chunk| {
                    S::shuffle_in(
                        chunk
                            .as_mut_ptr()
                            .cast::<Align64<[u32; 16]>>()
                            .as_mut()
                            .unwrap(),
                    );
                });

            // Front half: fill V[1..n], keeping the final B in registers
            let mut input_b = InRegisterAdapter::<R>::new();
            for i in 0..(n - 1) {
                let [src, dst] = v.get_disjoint_unchecked_mut([i, i + 1]);
                block_mix!(<R>; [<S> &*src => &mut *dst]);
            }
            block_mix!(<R>; [<S> v.get_unchecked(n - 1) => &mut input_b]);

            let mut idx = input_b.extract_idx() as usize & (n - 1);

            // Back half: B and T stay register-resident across iterations
            for _ in (0..n).step_by(2) {
                // for some reason this doesn't spill, so let's leave it as is
                let mut input_t = InRegisterAdapter::<R>::new();
                block_mix!(<R>; [<S> (&input_b, v.get_unchecked(idx)) => &mut input_t]);

                idx = input_t.extract_idx() as usize & (n - 1);

                block_mix!(<R>; [<S> (&input_t, v.get_unchecked(idx)) => &mut input_b]);

                idx = input_b.extract_idx() as usize & (n - 1);
            }

            // SAFETY: index n is in bounds after the >=n+1 check
            input_b.write_back(v.get_unchecked_mut(n));

            v.get_unchecked_mut(n)
                .chunks_exact_mut(64)
                .for_each(|chunk| {
                    S::shuffle_out(
                        chunk
                            .as_mut_ptr()
                            .cast::<Align64<[u32; 16]>>()
                            .as_mut()
                            .unwrap(),
                    );
                });
        }
    }

    /// Perform a paired-halves RoMix operation with interleaved buffers, using AVX-512 registers as temporary storage for the latter (this) half of the pipeline.
    ///
    /// The former half is performed on `other` and the latter half is performed on `self`.
    ///
    /// # Panics
    ///
    /// Panics if the buffers are of different equivalent Cost Factors.
    pub(super) fn ro_mix_interleaved_ex_zmm<
        S: Salsa20<Lanes = U2, Block = core::arch::x86_64::__m512i>,
    >(
        &mut self,
        other: &mut Self,
    ) {
        assert!(
            R::USIZE <= MAX_R_FOR_UNROLLING,
            "ro_mix_interleaved_ex_zmm: R > {}",
            MAX_R_FOR_UNROLLING
        );
        let self_v = self.v.as_mut();
        let other_v = other.v.as_mut();

        let self_cf = length_to_cf(self_v.len());
        let other_cf = length_to_cf(other_v.len());
        assert_eq!(
            self_cf, other_cf,
            "ro_mix_interleaved_ex_zmm: self_cf != other_cf, are you passing two buffers of the same size?"
        );
        let n = 1 << self_cf;

        // at least n+1 long (index n is the highest touched); enforced by length_to_cf, so disabled for release builds
        debug_assert!(
            other_v.len() >= n + 1,
            "ro_mix_interleaved_ex_zmm: other.v.len() < n + 1"
        );
        // at least n+2 long (the scratch block at index n+1 is touched); enforced by length_to_cf, so disabled for release builds
        debug_assert!(
            self_v.len() >= n + 2,
            "ro_mix_interleaved_ex_zmm: self.v.len() < n + 2"
        );

        unsafe {
            other_v
                .get_unchecked_mut(0)
                .chunks_exact_mut(64)
                .for_each(|chunk| {
                    S::shuffle_in(
                        chunk
                            .as_mut_ptr()
                            .cast::<Align64<[u32; 16]>>()
                            .as_mut()
                            .unwrap(),
                    );
                });
        }

        let mut idx = integerify!(<R> unsafe { self_v.get_unchecked(n) }) & (n - 1);
        let mut input_b =
            InRegisterAdapter::<R>::init_with_block(unsafe { self_v.get_unchecked(n) });

        for i in (0..n).step_by(2) {
            let mut input_t = InRegisterAdapter::<R>::new();
            // SAFETY: the largest i value is n-2, so the largest index is n, which is in bounds after the >=n+1 check
            let [src, middle, dst] =
                unsafe { other_v.get_disjoint_unchecked_mut([i, i + 1, i + 2]) };

            let [self_vj, self_t] = unsafe { self_v.get_disjoint_unchecked_mut([idx, n + 1]) };
            if R::USIZE <= MAX_R_FOR_FULL_INTERLEAVED_ZMM {
                // T fits in registers alongside B
                block_mix!(<R>; [<S> &*src => &mut *middle, <S> (&*self_vj, &input_b) => &mut input_t]);
                idx = input_t.extract_idx() as usize & (n - 1);
            } else {
                // T spills to the scratch block at index n+1
                block_mix!(
                    <R>; [<S> &*src => &mut *middle, <S> (&*self_vj, &input_b) => &mut *self_t]
                );
                idx = integerify!(<R> self_t) & (n - 1);
            }

            let [self_vj, self_t] = unsafe { self_v.get_disjoint_unchecked_mut([idx, n + 1]) };
            {
                if R::USIZE <= MAX_R_FOR_FULL_INTERLEAVED_ZMM {
                    block_mix!(<R>; [<S> &*middle => &mut *dst, <S> (&*self_vj, &input_t) => &mut input_b]);
                } else {
                    block_mix!(<R>; [<S> &*middle => &mut *dst, <S> (&*self_vj, &*self_t) => &mut input_b]);
                }

                idx = input_b.extract_idx() as usize & (n - 1);
            }
        }

        input_b.write_back(unsafe { self_v.get_unchecked_mut(n) });

        unsafe {
            self_v
                .get_unchecked_mut(n)
                .chunks_exact_mut(64)
                .for_each(|chunk| {
                    S::shuffle_out(
                        chunk
                            .as_mut_ptr()
                            .cast::<Align64<[u32; 16]>>()
                            .as_mut()
                            .unwrap(),
                    );
                });
        }
    }
}

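// How the slice-backed BlockMix I/O maps to the spec: BlockMix emits
// Y_0..Y_{2r-1} and B' is the shuffle (Y_0, Y_2, ..., Y_{2r-2}, Y_1, Y_3, ...,
// Y_{2r-1}), so the even-indexed outputs occupy the first half of the block and
// the odd-indexed outputs the second half, starting at byte offset 64 * r
// (`Mul64::<R>`).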
impl<'a, R: ArrayLength, B: BlockType> ScryptBlockMixInput<'a, B> for &'a Align64<Block<R>> {
    #[inline(always)]
    unsafe fn load(&self, word_idx: usize) -> B {
        unsafe { B::read_from_ptr(self.as_ptr().add(word_idx * 64).cast()) }
    }
}
impl<'a, R: ArrayLength, B: BlockType> ScryptBlockMixOutput<'a, R, B>
    for &'a mut Align64<Block<R>>
{
    #[inline(always)]
    fn store_even(&mut self, word_idx: usize, value: B) {
        debug_assert!(word_idx * 64 < self.len() / 2);
        unsafe { B::write_to_ptr(value, self.as_mut_ptr().add(word_idx * 64).cast()) }
    }
    #[inline(always)]
    fn store_odd(&mut self, word_idx: usize, value: B) {
        debug_assert!(Mul64::<R>::USIZE + word_idx * 64 < self.len());
        unsafe {
            B::write_to_ptr(
                value,
                self.as_mut_ptr()
                    .add(Mul64::<R>::USIZE + word_idx * 64)
                    .cast(),
            )
        }
    }
}

#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
#[repr(align(64))]
struct InRegisterAdapter<R: ArrayLength> {
    words: GenericArray<core::arch::x86_64::__m512i, Mul2<R>>,
}
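
// The adapter holds one 128*R-byte scrypt block as 2*R ZMM words so the "B"
// and "T" working blocks of the back half can stay register-resident:
// `store_even`/`store_odd` place BlockMix outputs in the canonical B' order
// (Y_0, Y_2, ..., Y_{2r-2}, Y_1, Y_3, ..., Y_{2r-1}), `extract_idx` reads the
// Integerify word, and `write_back` copies the block out to memory.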

#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
impl<R: ArrayLength> InRegisterAdapter<R> {
    #[inline(always)]
    fn new() -> Self {
        Self {
            // SAFETY: the words are always fully written by a BlockMix before being read
            words: unsafe { core::mem::MaybeUninit::uninit().assume_init() },
        }
    }

    #[inline(always)]
    fn init_with_block(block: &Align64<crate::fixed_r::Block<R>>) -> Self {
        use generic_array::sequence::GenericSequence;
        Self {
            words: unsafe {
                GenericArray::generate(|i| {
                    core::arch::x86_64::_mm512_load_si512(block.as_ptr().add(i * 64).cast())
                })
            },
        }
    }

    #[inline(always)]
    fn write_back(&mut self, output: &mut Align64<crate::fixed_r::Block<R>>) {
        use core::arch::x86_64::*;
        unsafe {
            for i in 0..R::USIZE {
                _mm512_store_si512(output.as_mut_ptr().add(i * 128).cast(), self.words[i * 2]);
                _mm512_store_si512(
                    output.as_mut_ptr().add(i * 128 + 64).cast(),
                    self.words[i * 2 + 1],
                );
            }
        }
    }

    #[inline(always)]
    fn extract_idx(&self) -> u32 {
        // The low 32-bit lane of the last word is the first word of the last
        // 64-byte sub-block, i.e. the same value `integerify!` reads from memory.
        unsafe {
            core::arch::x86_64::_mm_cvtsi128_si32(core::arch::x86_64::_mm512_castsi512_si128(
                self.words[R::USIZE * 2 - 1],
            )) as u32
        }
    }
}

#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
impl<'a, R: ArrayLength> ScryptBlockMixInput<'a, core::arch::x86_64::__m512i>
    for &'a InRegisterAdapter<R>
{
    #[inline(always)]
    unsafe fn load(&self, word_idx: usize) -> core::arch::x86_64::__m512i {
        self.words[word_idx]
    }
}

#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
impl<'a, R: ArrayLength> ScryptBlockMixOutput<'a, R, core::arch::x86_64::__m512i>
    for &'a mut InRegisterAdapter<R>
{
    #[inline(always)]
    fn store_even(&mut self, word_idx: usize, value: core::arch::x86_64::__m512i) {
        self.words[word_idx] = value;
    }
    #[inline(always)]
    fn store_odd(&mut self, word_idx: usize, value: core::arch::x86_64::__m512i) {
        self.words[R::USIZE + word_idx] = value;
    }
}