scrypt_opt/fixed_r.rs

use core::num::NonZeroU8;

use generic_array::{
    ArrayLength, GenericArray,
    typenum::{B0, NonZero, U1, U2, UInt, Unsigned},
};

#[allow(unused_imports)]
use crate::features::Feature as _;
#[cfg(feature = "alloc")]
use crate::memory;
use crate::{
    DefaultEngine1, DefaultEngine2, MAX_N, ScryptBlockMixInput, ScryptBlockMixOutput,
    memory::Align64,
    pbkdf2_1::Pbkdf2HmacSha256State,
    pipeline::PipelineContext,
    salsa20::{BlockType, Salsa20},
};

/// The type for one scrypt BlockMix block (128 * R bytes, i.e. 128 bytes per unit of R)
pub type Block<R> = GenericArray<u8, Mul128<R>>;

include!("block_mix.rs");

#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
const MAX_R_FOR_FULL_INTERLEAVED_ZMM: usize = 6; // 6 * 2 * 2 = 24 registers
const MAX_R_FOR_UNROLLING: usize = 8;

pub(crate) type Mul2<U> = UInt<U, B0>;
pub(crate) type Mul4<U> = UInt<Mul2<U>, B0>;
pub(crate) type Mul8<U> = UInt<Mul4<U>, B0>;
pub(crate) type Mul16<U> = UInt<Mul8<U>, B0>;
pub(crate) type Mul32<U> = UInt<Mul16<U>, B0>;
pub(crate) type Mul64<U> = UInt<Mul32<U>, B0>;
pub(crate) type Mul128<U> = UInt<Mul64<U>, B0>;
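// Appending a zero bit (`UInt<U, B0>`) doubles a typenum unsigned integer, so each alias above
// is a compile-time multiplication by a power of two. For example (assuming `typenum::U8` for
// R = 8), `Mul128<U8>` is the type-level constant 1024, the byte length of one `Block<U8>`.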

macro_rules! integerify {
    (<$r:ty> $x:expr) => {{
        let input: &crate::fixed_r::Block<R> = $x;

        debug_assert_eq!(
            input.as_ptr().align_offset(64),
            0,
            "unexpected input alignment"
        );
        debug_assert_eq!(input.len(), Mul128::<R>::USIZE, "unexpected input length");
        #[allow(unused_unsafe)]
        unsafe {
            input
                .as_ptr()
                .cast::<u8>()
                .add(Mul128::<R>::USIZE - 64)
                .cast::<u32>()
                .read() as usize
        }
    }};
}
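// For reference, RFC 7914 defines Integerify(B) as the little-endian u32 formed from the first
// 4 bytes of the last 64-byte sub-block of the 128 * r byte block B; the caller reduces it mod N.
// A safe-Rust sketch of the same read (a hypothetical helper, not used by the macro above, which
// performs an aligned native-endian pointer read instead):
//
//     fn integerify_reference(block: &[u8], r: usize) -> usize {
//         let off = 128 * r - 64;
//         u32::from_le_bytes(block[off..off + 4].try_into().unwrap()) as usize
//     }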

#[inline(always)]
/// Convert a number of blocks to the equivalent Cost Factor (floor(log2(blocks - 2)))
pub const fn length_to_cf(l: usize) -> u8 {
    let v = (l.saturating_sub(2)) as u32;
    ((32 - v.leading_zeros()) as u8).saturating_sub(1)
}

/// Get the minimum number of blocks required for a given Cost Factor ((1 << cf) + 2)
#[inline(always)]
pub const fn minimum_blocks(cf: NonZeroU8) -> usize {
    let n = 1 << cf.get();
    n + 2
}
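// A quick sanity sketch of how the two helpers relate (illustrative values):
//
//     // cf = 10 needs (1 << 10) + 2 = 1026 blocks, and 1026 blocks map back to cf = 10
//     assert_eq!(minimum_blocks(NonZeroU8::new(10).unwrap()), 1026);
//     assert_eq!(length_to_cf(1026), 10);
//     // extra blocks do not raise the Cost Factor until the next power of two (plus 2) fits
//     assert_eq!(length_to_cf(2049), 10);
//     assert_eq!(length_to_cf(2050), 11);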

#[derive(Debug, Clone, Copy)]
#[repr(transparent)]
/// A set of working buffers for scrypt with a compile-time fixed R value
pub struct BufferSet<
    Q: AsRef<[Align64<Block<R>>]> + AsMut<[Align64<Block<R>>]>,
    R: ArrayLength + NonZero,
> {
    v: Q,
    _r: core::marker::PhantomData<R>,
}

impl<Q: AsRef<[Align64<Block<R>>]> + AsMut<[Align64<Block<R>>]> + Default, R: ArrayLength + NonZero>
    Default for BufferSet<Q, R>
{
    #[inline(always)]
    fn default() -> Self {
        Self::new(Q::default())
    }
}

impl<Q: AsRef<[Align64<Block<R>>]> + AsMut<[Align64<Block<R>>]>, R: ArrayLength + NonZero> AsRef<Q>
    for BufferSet<Q, R>
{
    fn as_ref(&self) -> &Q {
        &self.v
    }
}

impl<Q: AsRef<[Align64<Block<R>>]> + AsMut<[Align64<Block<R>>]>, R: ArrayLength + NonZero> AsMut<Q>
    for BufferSet<Q, R>
{
    fn as_mut(&mut self) -> &mut Q {
        &mut self.v
    }
}

#[cfg(feature = "alloc")]
impl<R: ArrayLength + NonZero> BufferSet<alloc::vec::Vec<Align64<Block<R>>>, R> {
    /// Create a new buffer set in a box with a given Cost Factor (log2(N))
    #[inline(always)]
    pub fn new_boxed(cf: core::num::NonZeroU8) -> alloc::boxed::Box<Self> {
        let mut v = alloc::vec::Vec::new();
        v.resize(minimum_blocks(cf), Align64::<Block<R>>::default());
        alloc::boxed::Box::new(Self {
            v,
            _r: core::marker::PhantomData,
        })
    }
}
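// Rough memory budget: a buffer set holds (N + 2) blocks of 128 * R bytes each, so for example
// N = 16384 (cf = 14) with R = 8 needs (16384 + 2) * 128 * 8 bytes, roughly 16 MiB.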

#[cfg(feature = "alloc")]
impl<R: ArrayLength + NonZero> BufferSet<memory::MaybeHugeSlice<Align64<Block<R>>>, R> {
    /// Create a new buffer set in a huge page with a given Cost Factor (log2(N))
    #[inline(always)]
    pub fn new_maybe_huge_slice(
        cf: core::num::NonZeroU8,
    ) -> BufferSet<memory::MaybeHugeSlice<Align64<Block<R>>>, R> {
        use crate::memory;

        BufferSet {
            v: memory::MaybeHugeSlice::new_maybe(minimum_blocks(cf)),
            _r: core::marker::PhantomData,
        }
    }

    /// Create a new buffer set in a huge page with a given Cost Factor (log2(N)), also returning the allocation error, if any
    #[inline(always)]
    #[cfg(feature = "std")]
    pub fn new_maybe_huge_slice_ex(
        cf: core::num::NonZeroU8,
    ) -> (
        BufferSet<memory::MaybeHugeSlice<Align64<Block<R>>>, R>,
        Option<std::io::Error>,
    ) {
        let (v, e) = memory::MaybeHugeSlice::new(minimum_blocks(cf));
        (
            BufferSet {
                v,
                _r: core::marker::PhantomData,
            },
            e,
        )
    }
}

impl<Q: AsRef<[Align64<Block<R>>]> + AsMut<[Align64<Block<R>>]>, R: ArrayLength + NonZero>
    BufferSet<Q, R>
{
    /// Create a new buffer set
    ///
    /// # Panics
    ///
    /// Panics if the number of blocks is less than 4 or greater than MAX_N + 2.
    pub fn new(q: Q) -> Self {
        let l = q.as_ref().len();
        assert!(l >= 4, "number of blocks must be at least 4");
        assert!(
            l - 2 <= MAX_N as usize,
            "number of blocks must be at most MAX_N + 2"
        );
        Self {
            v: q,
            _r: core::marker::PhantomData,
        }
    }

    /// Create a new buffer set if the number of blocks is between 4 and MAX_N + 2
    ///
    /// # Returns
    ///
    /// None if the number of blocks is less than 4 or greater than MAX_N + 2
    pub fn try_new(q: Q) -> Option<Self> {
        let l = q.as_ref().len();
        if l < 4 {
            return None;
        }
        if l > MAX_N as usize + 2 {
            return None;
        }
        Some(Self {
            v: q,
            _r: core::marker::PhantomData,
        })
    }

    /// Consume the buffer set and return the inner buffer
    pub fn into_inner(self) -> Q {
        self.v
    }

    /// Get the input block buffer
    pub fn input_buffer(&self) -> &Align64<Block<R>> {
        &self.v.as_ref()[0]
    }

    /// Get the input block buffer mutably
    pub fn input_buffer_mut(&mut self) -> &mut Align64<Block<R>> {
        &mut self.v.as_mut()[0]
    }

    /// Set the input for the block buffer
    #[inline(always)]
    pub fn set_input(&mut self, hmac_state: &Pbkdf2HmacSha256State, salt: &[u8]) {
        hmac_state.emit_scatter(salt, [self.input_buffer_mut()]);
    }

    #[inline(always)]
    /// Get the effective Cost Factor (log2(N)) for the buffer set
    pub fn cf(&self) -> u8 {
        let l = self.v.as_ref().len();

        length_to_cf(l)
    }

    #[inline(always)]
    /// Get the effective N value for the buffer set
    pub fn n(&self) -> usize {
        let cf = self.cf();
        1 << cf
    }

    /// Get the raw salt output, useful for concatenation in P > 1 cases
    #[inline(always)]
    pub fn raw_salt_output(&self) -> &Align64<Block<R>> {
        unsafe { self.v.as_ref().get_unchecked(self.n()) }
    }

    /// Extract the output from the block buffer
    #[inline(always)]
    pub fn extract_output(&self, hmac_state: &Pbkdf2HmacSha256State, output: &mut [u8]) {
        hmac_state.emit_gather([self.raw_salt_output()], output);
    }

    /// Shorten the buffer set into a smaller buffer set and return the remainder as a slice.
    /// This is useful if you allocate once for the largest N you intend to use and reuse the
    /// allocation for smaller Cost Factors.
    ///
    /// # Returns
    ///
    /// None if the number of blocks is less than the minimum number of blocks for the given Cost Factor.
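    ///
    /// # Example
    ///
    /// A sketch (assuming `U8` from `generic_array::typenum` for R = 8 and the `alloc` feature):
    ///
    /// ```ignore
    /// let mut big = BufferSet::<_, U8>::new_boxed(NonZeroU8::new(20).unwrap());
    /// if let Some((mut small, _rest)) = big.shorten(NonZeroU8::new(14).unwrap()) {
    ///     small.scrypt_ro_mix(); // operates as an N = 16384 buffer set
    /// }
    /// ```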
    #[inline(always)]
    pub fn shorten(
        &mut self,
        cf: NonZeroU8,
    ) -> Option<(
        BufferSet<&mut [Align64<Block<R>>], R>,
        &mut [Align64<Block<R>>],
    )> {
        let min_blocks = minimum_blocks(cf);
        let (set, rest) = self.v.as_mut().split_at_mut_checked(min_blocks)?;
        Some((
            BufferSet {
                v: set,
                _r: core::marker::PhantomData,
            },
            rest,
        ))
    }

    /// Start an interleaved pipeline.
    #[cfg_attr(
        all(target_arch = "x86_64", not(target_feature = "avx2")),
        scrypt_opt_derive::generate_target_variant("avx2")
    )]
    pub(super) fn ro_mix_front_ex<S: Salsa20<Lanes = U1>>(&mut self) {
        let v = self.v.as_mut();
        let n = 1 << length_to_cf(v.len());

        // at least n+1 long, this is already enforced by length_to_cf so we can disable it for release builds
        debug_assert!(v.len() > n, "ro_mix_front_ex: v.len() <= n");

        unsafe {
            v.get_unchecked_mut(0)
                .chunks_exact_mut(64)
                .for_each(|chunk| {
                    S::shuffle_in(
                        chunk
                            .as_mut_ptr()
                            .cast::<Align64<[u32; 16]>>()
                            .as_mut()
                            .unwrap(),
                    );
                });
        }

        for i in 0..n {
            let [src, dst] = unsafe { v.get_disjoint_unchecked_mut([i, i + 1]) };
            block_mix!(R::USIZE; [<S> &*src => &mut *dst]);
        }
    }

    /// Drain an interleaved pipeline.
    #[cfg_attr(
        all(target_arch = "x86_64", not(target_feature = "avx2")),
        scrypt_opt_derive::generate_target_variant("avx2")
    )]
    pub(super) fn ro_mix_back_ex<S: Salsa20<Lanes = U1>>(&mut self) {
        let v = self.v.as_mut();
        let n = 1 << length_to_cf(v.len());
        // at least n+2 long, this is already enforced by length_to_cf so we can disable it for release builds
        debug_assert!(v.len() >= n + 2, "ro_mix_back_ex: v.len() < n + 2");

        for _ in (0..n).step_by(2) {
            let idx = integerify!(<R> unsafe { v.get_unchecked(n) });

            let j = idx & (n - 1);

            // SAFETY: the largest j value is n-1, so the largest index of the 3 is n+1, which is in bounds after the >=n+2 check
            let [in0, in1, out] = unsafe { v.get_disjoint_unchecked_mut([n, j, n + 1]) };
            block_mix!(R::USIZE; [<S> (&*in0, &*in1) => &mut *out]);
            let idx2 = integerify!(<R> unsafe { v.get_unchecked(n + 1) });

            let j2 = idx2 & (n - 1);

            // SAFETY: the largest j2 value is n-1, so the largest index of the 3 is n+1, which is in bounds after the >=n+2 check
            let [b, v, t] = unsafe { v.get_disjoint_unchecked_mut([n, j2, n + 1]) };
            block_mix!(R::USIZE; [<S> (&*v, &*t) => &mut *b]);
        }

        unsafe {
            v.get_unchecked_mut(n)
                .chunks_exact_mut(64)
                .for_each(|chunk| {
                    S::shuffle_out(
                        chunk
                            .as_mut_ptr()
                            .cast::<Align64<[u32; 16]>>()
                            .as_mut()
                            .unwrap(),
                    );
                });
        }
    }

    #[cfg_attr(
        all(target_arch = "x86_64", not(target_feature = "avx2")),
        scrypt_opt_derive::generate_target_variant("avx2")
    )]
    /// Interleaved RoMix operation.
    ///
    /// $RoMix_{Back}$ is performed on self and $RoMix_{Front}$ is performed on other.
    ///
    /// # Panics
    ///
    /// Panics if the buffers are of different equivalent Cost Factors.
    pub(super) fn ro_mix_interleaved_ex<S: Salsa20<Lanes = U2>>(&mut self, other: &mut Self) {
        let self_v = self.v.as_mut();
        let other_v = other.v.as_mut();
        let self_cf = length_to_cf(self_v.len());
        let other_cf = length_to_cf(other_v.len());
        assert_eq!(
            self_cf, other_cf,
            "ro_mix_interleaved_ex: self_cf != other_cf, both buffers must have the same effective Cost Factor"
        );
        let n = 1 << self_cf;

        // at least n+2 long, this is already enforced by length_to_cf so we can disable it for release builds
        debug_assert!(
            other_v.len() >= n + 2,
            "ro_mix_interleaved_ex: other_v.len() < n + 2"
        );
        // at least n+2 long, this is already enforced by length_to_cf so we can disable it for release builds
        debug_assert!(
            self_v.len() >= n + 2,
            "ro_mix_interleaved_ex: self_v.len() < n + 2"
        );

        // SAFETY: other_v is always 64-byte aligned
        unsafe {
            other_v
                .get_unchecked_mut(0)
                .chunks_exact_mut(64)
                .for_each(|chunk| {
                    S::shuffle_in(
                        chunk
                            .as_mut_ptr()
                            .cast::<Align64<[u32; 16]>>()
                            .as_mut()
                            .unwrap_unchecked(),
                    );
                });
        }

        for i in (0..n).step_by(2) {
            // SAFETY: the largest i value is n-2, so the largest index is n, which is in bounds after the >=n+2 check
            let [src, middle, dst] =
                unsafe { other_v.get_disjoint_unchecked_mut([i, i + 1, i + 2]) };

            {
                // Self: Compute T <- BlockMix(B ^ V[j])
                // Other: Compute V[i+1] <- BlockMix(V[i])
                let idx = integerify!(<R> unsafe { self_v.get_unchecked(n) });

                let j = idx & (n - 1);

                let [in0, in1, out] = unsafe { self_v.get_disjoint_unchecked_mut([j, n, n + 1]) };

                block_mix!(R::USIZE; [<S> &*src => &mut *middle, <S> (&*in0, &*in1) => &mut *out]);
            }

            {
                // Self: Compute B <- BlockMix(T ^ V[j'])
                // Other: Compute V[i+2] <- BlockMix(V[i+1]); on the last iteration this "naturally overflows" into V[n], so let B = V[n]
                let idx2 = integerify!(<R> unsafe { self_v.get_unchecked(n + 1) });

                let j2 = idx2 & (n - 1);
                let [self_b, self_v, self_t] =
                    unsafe { self_v.get_disjoint_unchecked_mut([n, j2, n + 1]) };

                block_mix!(R::USIZE; [<S> &*middle => &mut *dst, <S> (&*self_v, &*self_t) => &mut *self_b]);
            }
        }

        // SAFETY: self_v is always 64-byte aligned
        unsafe {
            self_v
                .get_unchecked_mut(n)
                .chunks_exact_mut(64)
                .for_each(|chunk| {
                    S::shuffle_out(
                        chunk
                            .as_mut_ptr()
                            .cast::<Align64<[u32; 16]>>()
                            .as_mut()
                            .unwrap_unchecked(),
                    );
                });
        }
    }

    /// Start an interleaved pipeline using the default engine by performing the $RoMix_{Front}$ operation.
    #[inline(always)]
    pub fn ro_mix_front(&mut self) {
        #[cfg(all(not(test), target_arch = "x86_64", not(target_feature = "avx2")))]
        {
            if crate::features::Avx2.check() {
                unsafe {
                    self.ro_mix_front_ex_avx2::<crate::salsa20::x86_64::BlockAvx2>();
                }
                return;
            }
        }

        self.ro_mix_front_ex::<DefaultEngine1>();
    }

    /// Drain an interleaved pipeline using the default engine by performing the $RoMix_{Back}$ operation.
    #[inline(always)]
    pub fn ro_mix_back(&mut self) {
        #[cfg(all(not(test), target_arch = "x86_64", not(target_feature = "avx2")))]
        {
            if crate::features::Avx2.check() {
                unsafe {
                    self.ro_mix_back_ex_avx2::<crate::salsa20::x86_64::BlockAvx2>();
                }
                return;
            }
        }

        self.ro_mix_back_ex::<DefaultEngine1>();
    }

    /// Perform the RoMix operation using the default engine.
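    ///
    /// # Example
    ///
    /// A minimal single-buffer (P = 1) sketch. `hmac_state` is assumed to be a
    /// `Pbkdf2HmacSha256State` already derived from the password (its constructor lives in
    /// `pbkdf2_1` and is not shown here), `salt`/`output` are the scrypt salt and derived-key
    /// buffer, and `U8` is `generic_array::typenum::U8` for R = 8:
    ///
    /// ```ignore
    /// let mut bufs = BufferSet::<_, U8>::new_boxed(NonZeroU8::new(14).unwrap()); // N = 16384
    /// bufs.set_input(&hmac_state, salt);
    /// bufs.scrypt_ro_mix();
    /// bufs.extract_output(&hmac_state, &mut output);
    /// ```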
    pub fn scrypt_ro_mix(&mut self) {
        // If possible, redirect to the register-resident implementation to avoid data access thrashing.
        #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
        if R::USIZE <= MAX_R_FOR_UNROLLING {
            self.scrypt_ro_mix_ex_zmm::<crate::salsa20::x86_64::BlockAvx512F>();
            return;
        }

        #[cfg(all(not(test), target_arch = "x86_64", not(target_feature = "avx2")))]
        {
            if crate::features::Avx2.check() {
                unsafe {
                    self.ro_mix_front_ex_avx2::<crate::salsa20::x86_64::BlockAvx2>();
                    self.ro_mix_back_ex_avx2::<crate::salsa20::x86_64::BlockAvx2>();
                }
                return;
            }
        }

        self.ro_mix_front_ex::<DefaultEngine1>();
        self.ro_mix_back_ex::<DefaultEngine1>();
    }

    /// Perform the RoMix operation with interleaved buffers.
    ///
    /// $RoMix_{Back}$ is performed on self and $RoMix_{Front}$ is performed on other.
    ///
    /// # Panics
    ///
    /// Panics if the buffers are of different equivalent Cost Factors.
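    ///
    /// # Example
    ///
    /// A sketch of hand-rolled two-way pipelining (the `pipeline` method below automates this
    /// schedule); `hmac_a`/`hmac_b` are assumed to be prepared `Pbkdf2HmacSha256State`s and both
    /// sets share the same Cost Factor and R:
    ///
    /// ```ignore
    /// set_a.set_input(&hmac_a, salt_a);
    /// set_a.ro_mix_front();
    /// set_b.set_input(&hmac_b, salt_b);
    /// // The back half of `a` runs interleaved with the front half of `b`.
    /// set_a.ro_mix_interleaved(&mut set_b);
    /// set_a.extract_output(&hmac_a, &mut out_a);
    /// set_b.ro_mix_back();
    /// set_b.extract_output(&hmac_b, &mut out_b);
    /// ```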
    pub fn ro_mix_interleaved(&mut self, other: &mut Self) {
        // If possible, steer to the register-resident AVX-512 implementation to avoid cache line thrashing.
        #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
        if R::USIZE <= MAX_R_FOR_UNROLLING {
            self.ro_mix_interleaved_ex_zmm::<crate::salsa20::x86_64::BlockAvx512FMb2>(other);
            return;
        }

        #[cfg(all(not(test), target_arch = "x86_64", not(target_feature = "avx2")))]
        {
            if crate::features::Avx2.check() {
                unsafe {
                    self.ro_mix_interleaved_ex_avx2::<crate::salsa20::x86_64::BlockAvx2Mb2>(other);
                }
                return;
            }
        }

        self.ro_mix_interleaved_ex::<DefaultEngine2>(other);
    }

    /// Pipeline RoMix operations on an iterator of inputs.
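    ///
    /// Inputs are processed two at a time: while the $RoMix_{Back}$ half runs for input `k` on one
    /// buffer set, the $RoMix_{Front}$ half for input `k + 1` runs on the other. For three inputs
    /// `x`, `y`, `z` the schedule is roughly:
    ///
    /// ```text
    /// front(x, A); back(x, A) + front(y, B); drain x;
    ///              back(y, B) + front(z, A); drain y;
    ///                                        back(z, A); drain z.
    /// ```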
    pub fn pipeline<K, S, C: PipelineContext<S, Q, R, K>, I: IntoIterator<Item = C>>(
        &mut self,
        other: &mut Self,
        iter: I,
        state: &mut S,
    ) -> Option<K> {
        let mut iter = iter.into_iter();

        let (mut buffers0, mut buffers1) = (&mut *self, &mut *other);
        let Some(mut input_m2) = iter.next() else {
            return None;
        };
        input_m2.begin(state, buffers0);
        let Some(mut input_m1) = iter.next() else {
            buffers0.scrypt_ro_mix();
            return input_m2.drain(state, buffers0);
        };
        input_m1.begin(state, buffers1);

        #[cfg(all(not(test), target_arch = "x86_64", not(target_feature = "avx2")))]
        {
            if crate::features::Avx2.check() {
                unsafe {
                    buffers0.ro_mix_front_ex_avx2::<crate::salsa20::x86_64::BlockAvx2>();
                    loop {
                        buffers0
                            .ro_mix_interleaved_ex_avx2::<crate::salsa20::x86_64::BlockAvx2Mb2>(
                                buffers1,
                            );
                        if let Some(k) = input_m2.drain(state, buffers0) {
                            return Some(k);
                        }

                        (buffers0, buffers1) = (buffers1, buffers0);

                        let Some(mut input) = iter.next() else {
                            break;
                        };

                        input.begin(state, buffers1);

                        input_m2 = input_m1;
                        input_m1 = input;
                    }
                    buffers0.ro_mix_back_ex_avx2::<crate::salsa20::x86_64::BlockAvx2>();
                    return input_m1.drain(state, buffers0);
                }
            }
        }

        buffers0.ro_mix_front();
        loop {
            buffers0.ro_mix_interleaved(buffers1);
            if let Some(k) = input_m2.drain(state, buffers0) {
                return Some(k);
            }

            (buffers0, buffers1) = (buffers1, buffers0);

            let Some(mut input) = iter.next() else {
                break;
            };

            input.begin(state, buffers1);

            input_m2 = input_m1;
            input_m1 = input;
        }
        buffers0.ro_mix_back();
        input_m1.drain(state, buffers0)
    }
}

#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
impl<Q: AsRef<[Align64<Block<R>>]> + AsMut<[Align64<Block<R>>]>, R: ArrayLength + NonZero>
    BufferSet<Q, R>
{
    /// Perform the RoMix operation using AVX-512 registers as temporary storage.
    #[inline(always)]
    pub(super) fn scrypt_ro_mix_ex_zmm<
        S: Salsa20<Lanes = U1, Block = core::arch::x86_64::__m512i>,
    >(
        &mut self,
    ) {
        assert!(
            R::USIZE <= MAX_R_FOR_UNROLLING,
            "scrypt_ro_mix_ex_zmm: R > {}",
            MAX_R_FOR_UNROLLING
        );
        let v = self.v.as_mut();
        let n = 1 << length_to_cf(v.len());
        // at least n+1 long, this is checked by length_to_cf
        debug_assert!(v.len() > n, "scrypt_ro_mix_ex_zmm: v.len() <= n");

        unsafe {
            v.get_unchecked_mut(0)
                .chunks_exact_mut(64)
                .for_each(|chunk| {
                    S::shuffle_in(
                        chunk
                            .as_mut_ptr()
                            .cast::<Align64<[u32; 16]>>()
                            .as_mut()
                            .unwrap(),
                    );
                });

            let mut input_b = InRegisterAdapter::<R>::new();
            for i in 0..(n - 1) {
                let [src, dst] = v.get_disjoint_unchecked_mut([i, i + 1]);
                block_mix!(R::USIZE; [<S> &*src => &mut *dst]);
            }
            block_mix!(R::USIZE; [<S> v.get_unchecked(n - 1) => &mut input_b]);

            let mut idx = input_b.extract_idx() as usize & (n - 1);

            for _ in (0..n).step_by(2) {
                // for some reason this doesn't spill, so let's leave it as is
                let mut input_t = InRegisterAdapter::<R>::new();
                block_mix!(R::USIZE; [<S> (&input_b, v.get_unchecked(idx)) => &mut input_t]);

                idx = input_t.extract_idx() as usize & (n - 1);

                block_mix!(R::USIZE; [<S> (&input_t, v.get_unchecked(idx)) => &mut input_b]);

                idx = input_b.extract_idx() as usize & (n - 1);
            }

            // SAFETY: n is in bounds after the >=n+1 check
            input_b.write_back(v.get_unchecked_mut(n));

            v.get_unchecked_mut(n)
                .chunks_exact_mut(64)
                .for_each(|chunk| {
                    S::shuffle_out(
                        chunk
                            .as_mut_ptr()
                            .cast::<Align64<[u32; 16]>>()
                            .as_mut()
                            .unwrap(),
                    );
                });
        }
    }

    /// Perform a paired-halves RoMix operation with interleaved buffers using AVX-512 registers as temporary storage for the latter (this) half pipeline.
    ///
    /// The former half is performed on `other` and the latter half is performed on `self`.
    ///
    /// # Panics
    ///
    /// Panics if the buffers are of different equivalent Cost Factors.
    #[inline(always)]
    pub(super) fn ro_mix_interleaved_ex_zmm<
        S: Salsa20<Lanes = U2, Block = core::arch::x86_64::__m512i>,
    >(
        &mut self,
        other: &mut Self,
    ) {
        assert!(
            R::USIZE <= MAX_R_FOR_UNROLLING,
            "ro_mix_interleaved_ex_zmm: R > {}",
            MAX_R_FOR_UNROLLING
        );
        let self_v = self.v.as_mut();
        let other_v = other.v.as_mut();

        let self_cf = length_to_cf(self_v.len());
        let other_cf = length_to_cf(other_v.len());
        assert_eq!(
            self_cf, other_cf,
            "ro_mix_interleaved_ex_zmm: self_cf != other_cf, both buffers must have the same effective Cost Factor"
        );
        let n = 1 << self_cf;

        // at least n+1 long, this is already enforced by length_to_cf so we can disable it for release builds
        debug_assert!(
            other_v.len() >= n + 1,
            "ro_mix_interleaved_ex_zmm: other.v.len() < n + 1"
        );
        // at least n+1 long, this is already enforced by length_to_cf so we can disable it for release builds
        debug_assert!(
            self_v.len() >= n + 1,
            "ro_mix_interleaved_ex_zmm: self.v.len() < n + 1"
        );

        unsafe {
            other_v
                .get_unchecked_mut(0)
                .chunks_exact_mut(64)
                .for_each(|chunk| {
                    S::shuffle_in(
                        chunk
                            .as_mut_ptr()
                            .cast::<Align64<[u32; 16]>>()
                            .as_mut()
                            .unwrap(),
                    );
                });
        }

        let mut idx = integerify!(<R> unsafe { self_v.get_unchecked(n) });
        idx &= n - 1;
        let mut input_b =
            InRegisterAdapter::<R>::init_with_block(unsafe { self_v.get_unchecked(n) });

        for i in (0..n).step_by(2) {
            let mut input_t = InRegisterAdapter::<R>::new();
            // SAFETY: the largest i value is n-2, so the largest index is n, which is in bounds after the >=n+1 check
            let [src, middle, dst] =
                unsafe { other_v.get_disjoint_unchecked_mut([i, i + 1, i + 2]) };

            let [self_vj, self_t] = unsafe { self_v.get_disjoint_unchecked_mut([idx, n + 1]) };
            if R::USIZE <= MAX_R_FOR_FULL_INTERLEAVED_ZMM {
                block_mix!(R::USIZE; [<S> &*src => &mut *middle, <S> (&*self_vj, &input_b) => &mut input_t]);
                idx = input_t.extract_idx() as usize & (n - 1);
            } else {
                block_mix!(
                    R::USIZE; [<S> &*src => &mut *middle, <S> (&*self_vj, &input_b) => &mut *self_t]
                );
                idx = integerify!(<R> self_t) & (n - 1);
            }

            let [self_vj, self_t] = unsafe { self_v.get_disjoint_unchecked_mut([idx, n + 1]) };
            {
                if R::USIZE <= MAX_R_FOR_FULL_INTERLEAVED_ZMM {
                    block_mix!(R::USIZE; [<S> &*middle => &mut *dst, <S> (&*self_vj, &input_t) => &mut input_b]);
                } else {
                    block_mix!(R::USIZE; [<S> &*middle => &mut *dst, <S> (&*self_vj, &*self_t) => &mut input_b]);
                }

                idx = input_b.extract_idx() as usize & (n - 1);
            }
        }

        input_b.write_back(unsafe { self_v.get_unchecked_mut(n) });

        unsafe {
            self_v
                .get_unchecked_mut(n)
                .chunks_exact_mut(64)
                .for_each(|chunk| {
                    S::shuffle_out(
                        chunk
                            .as_mut_ptr()
                            .cast::<Align64<[u32; 16]>>()
                            .as_mut()
                            .unwrap(),
                    );
                });
        }
    }
}

impl<'a, R: ArrayLength, B: BlockType> ScryptBlockMixInput<'a, B> for &'a Align64<Block<R>> {
    #[inline(always)]
    unsafe fn load(&self, word_idx: usize) -> B {
        unsafe { B::read_from_ptr(self.as_ptr().add(word_idx * 64).cast()) }
    }
}
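// These stores follow the RFC 7914 BlockMix output ordering
// B' = (Y_0, Y_2, ..., Y_{2r-2}, Y_1, Y_3, ..., Y_{2r-1}): `store_even` places Y_{2i} in the
// first 64 * r bytes of the block and `store_odd` places Y_{2i+1} in the second 64 * r bytes.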
impl<'a, R: ArrayLength, B: BlockType> ScryptBlockMixOutput<'a, R, B>
    for &'a mut Align64<Block<R>>
{
    #[inline(always)]
    fn store_even(&mut self, word_idx: usize, value: B) {
        debug_assert!(word_idx * 64 < self.len() / 2);
        unsafe { B::write_to_ptr(value, self.as_mut_ptr().add(word_idx * 64).cast()) }
    }
    #[inline(always)]
    fn store_odd(&mut self, word_idx: usize, value: B) {
        debug_assert!(Mul64::<R>::USIZE + word_idx * 64 < self.len());
        unsafe {
            B::write_to_ptr(
                value,
                self.as_mut_ptr()
                    .add(Mul64::<R>::USIZE + word_idx * 64)
                    .cast(),
            )
        }
    }
}

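/// Adapter that keeps one 128 * R byte block as 2 * R `__m512i` values, intended to stay
/// resident in ZMM registers across BlockMix calls instead of round-tripping through memory.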
#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
#[repr(align(64))]
struct InRegisterAdapter<R: ArrayLength> {
    words: GenericArray<core::arch::x86_64::__m512i, Mul2<R>>,
}

#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
impl<R: ArrayLength> InRegisterAdapter<R> {
    #[inline(always)]
    fn new() -> Self {
        Self {
            words: unsafe { core::mem::MaybeUninit::uninit().assume_init() },
        }
    }

    #[inline(always)]
    fn init_with_block(block: &Align64<crate::fixed_r::Block<R>>) -> Self {
        use generic_array::sequence::GenericSequence;
        Self {
            words: unsafe {
                GenericArray::generate(|i| {
                    core::arch::x86_64::_mm512_load_si512(block.as_ptr().add(i * 64).cast())
                })
            },
        }
    }

    #[inline(always)]
    fn write_back(&mut self, output: &mut Align64<crate::fixed_r::Block<R>>) {
        use core::arch::x86_64::*;
        unsafe {
            for i in 0..R::USIZE {
                _mm512_store_si512(output.as_mut_ptr().add(i * 128).cast(), self.words[i * 2]);
                _mm512_store_si512(
                    output.as_mut_ptr().add(i * 128 + 64).cast(),
                    self.words[i * 2 + 1],
                );
            }
        }
    }

    #[inline(always)]
    fn extract_idx(&self) -> u32 {
        unsafe {
            core::arch::x86_64::_mm_cvtsi128_si32(core::arch::x86_64::_mm512_castsi512_si128(
                self.words[R::USIZE * 2 - 1],
            )) as u32
        }
    }
}

#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
impl<'a, R: ArrayLength> ScryptBlockMixInput<'a, core::arch::x86_64::__m512i>
    for &'a InRegisterAdapter<R>
{
    #[inline(always)]
    unsafe fn load(&self, word_idx: usize) -> core::arch::x86_64::__m512i {
        self.words[word_idx]
    }
}

#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
impl<'a, R: ArrayLength> ScryptBlockMixOutput<'a, R, core::arch::x86_64::__m512i>
    for &'a mut InRegisterAdapter<R>
{
    #[inline(always)]
    fn store_even(&mut self, word_idx: usize, value: core::arch::x86_64::__m512i) {
        self.words[word_idx] = value;
    }
    #[inline(always)]
    fn store_odd(&mut self, word_idx: usize, value: core::arch::x86_64::__m512i) {
        self.words[R::USIZE + word_idx] = value;
    }
}