// scrypt_opt/compat.rs

use core::num::{NonZeroU8, NonZeroU32};

use generic_array::typenum::U1;
#[cfg(target_arch = "wasm32")]
use wasm_bindgen::prelude::*;

use crate::{Align64, RoMix, fixed_r, memory::MaybeHugeSlice, pbkdf2_1::Pbkdf2HmacSha256State};

/// API constant for unsupported parameters.
pub const SCRYPT_OPT_UNSUPPORTED_PARAM_SPACE: core::ffi::c_int = -1;

/// API constant for invalid buffer sizes.
pub const SCRYPT_OPT_INVALID_BUFFER_SIZE: core::ffi::c_int = -2;

/// API constant for invalid buffer alignments.
pub const SCRYPT_OPT_INVALID_BUFFER_ALIGNMENT: core::ffi::c_int = -3;

/// Run scrypt with the given parameters and store the derived key in the output buffer.
/// `log2_n` is the cost factor, i.e. the base-2 logarithm of N.
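///
/// A minimal usage sketch (fenced as `ignore` because the `scrypt_opt::compat` path is an
/// assumption about how this module is exposed; adjust it to the actual crate layout):
///
/// ```ignore
/// use core::num::{NonZeroU8, NonZeroU32};
///
/// let mut output = [0u8; 64];
/// scrypt_opt::compat::scrypt(
///     b"password",
///     b"salt",
///     NonZeroU8::new(10).unwrap(),  // cost factor: N = 1 << 10
///     NonZeroU32::new(8).unwrap(),  // r
///     NonZeroU32::new(1).unwrap(),  // p
///     &mut output,
/// );
/// ```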
pub fn scrypt(
    password: &[u8],
    salt: &[u8],
    log2_n: NonZeroU8,
    r: NonZeroU32,
    p: NonZeroU32,
    output: &mut [u8],
) {
    // Working buffer 0: r * (N + 2) blocks, where N = 1 << log2_n.
    let mut buffers0 = MaybeHugeSlice::<Align64<fixed_r::Block<U1>>>::new(
        r.get() as usize * ((1 << log2_n.get()) + 2),
    )
    .0;

    let hmac_state = Pbkdf2HmacSha256State::new(password);

    // PBKDF2-expand the salt directly into the RoMix input buffer for the first chunk.
    hmac_state.emit_scatter(
        salt,
        buffers0
            .ro_mix_input_buffer(r)
            .chunks_exact_mut(core::mem::size_of::<Align64<fixed_r::Block<U1>>>())
            .map(|chunk| unsafe {
                chunk
                    .as_mut_ptr()
                    .cast::<Align64<fixed_r::Block<U1>>>()
                    .as_mut()
                    .unwrap()
            }),
    );
    buffers0.ro_mix_front(r, log2_n);

    // Single-chunk case: finish RoMix and PBKDF2 the resulting salt straight into the output.
    if p.get() == 1 {
        hmac_state.emit_gather(
            buffers0
                .ro_mix_back(r, log2_n)
                .chunks_exact(core::mem::size_of::<Align64<fixed_r::Block<U1>>>())
                .map(|block| unsafe {
                    block
                        .as_ptr()
                        .cast::<Align64<fixed_r::Block<U1>>>()
                        .as_ref()
                        .unwrap()
                }),
            output,
        );

        return;
    }

    let mut output_hmac_state = hmac_state.clone();

    // Second working buffer so the back half of one chunk can run interleaved with the
    // front half of the next chunk.
    let mut buffers1 = MaybeHugeSlice::<Align64<fixed_r::Block<U1>>>::new(
        r.get() as usize * ((1 << log2_n.get()) + 2),
    )
    .0;

    for chunk_idx in 1..p.get() {
        // PBKDF2-expand the salt into the RoMix input buffer for this chunk; the offset is
        // in PBKDF2 output blocks (each 128 * r byte chunk spans 4 * r SHA-256 blocks).
        hmac_state.emit_scatter_offset(
            salt,
            buffers1
                .ro_mix_input_buffer(r)
                .chunks_exact_mut(core::mem::size_of::<Align64<fixed_r::Block<U1>>>())
                .map(|chunk| unsafe {
                    chunk
                        .as_mut_ptr()
                        .cast::<Align64<fixed_r::Block<U1>>>()
                        .as_mut()
                        .unwrap()
                }),
            chunk_idx * 4 * r.get(),
        );

        // Finish the previous chunk while starting this one; the result is the previous
        // chunk's raw salt, which is folded into the output HMAC state.
        let salt = buffers0.ro_mix_interleaved(&mut buffers1, r, log2_n);

        output_hmac_state.ingest_salt(unsafe {
            core::slice::from_raw_parts(
                salt.as_ptr().cast::<Align64<fixed_r::Block<U1>>>(),
                salt.len() / core::mem::size_of::<Align64<fixed_r::Block<U1>>>(),
            )
        });

        (buffers0, buffers1) = (buffers1, buffers0);
    }
    // Finish the last chunk and PBKDF2 the gathered salts into the output.
    output_hmac_state.emit_gather(
        buffers0
            .ro_mix_back(r, log2_n)
            .chunks_exact(core::mem::size_of::<Align64<fixed_r::Block<U1>>>())
            .map(|block| unsafe {
                block
                    .as_ptr()
                    .cast::<Align64<fixed_r::Block<U1>>>()
                    .as_ref()
                    .unwrap()
            }),
        output,
    );
}

#[unsafe(export_name = "scrypt_kdf_cf")]
/// C export for scrypt_kdf using a libscrypt-kdf compatible API, except the cost is passed as a cost factor (log2 of N) instead of N.
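///
/// A sketch of calling this export from Rust rather than C (fenced as `ignore` because the
/// `scrypt_opt::compat` path is an assumption; C callers declare and call the exported
/// `scrypt_kdf_cf` symbol instead):
///
/// ```ignore
/// let (password, salt) = (b"password", b"salt");
/// let mut output = [0u8; 64];
/// let rc = unsafe {
///     scrypt_opt::compat::scrypt_c_cf(
///         password.as_ptr(), password.len(),
///         salt.as_ptr(), salt.len(),
///         10, // cost factor: N = 1 << 10
///         8,  // r
///         1,  // p
///         output.as_mut_ptr(), output.len(),
///     )
/// };
/// assert_eq!(rc, 0);
/// ```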
pub unsafe extern "C" fn scrypt_c_cf(
    password: *const u8,
    password_len: usize,
    salt: *const u8,
    salt_len: usize,
    log2_n: u8,
    r: u32,
    p: u32,
    output: *mut u8,
    output_len: usize,
) -> core::ffi::c_int {
    let password = unsafe { core::slice::from_raw_parts(password, password_len) };
    let salt = unsafe { core::slice::from_raw_parts(salt, salt_len) };
    let output = unsafe { core::slice::from_raw_parts_mut(output, output_len) };
    let Some(r) = NonZeroU32::new(r) else {
        return SCRYPT_OPT_UNSUPPORTED_PARAM_SPACE;
    };
    let Some(p) = NonZeroU32::new(p) else {
        return SCRYPT_OPT_UNSUPPORTED_PARAM_SPACE;
    };
    let Some(log2_n) = NonZeroU8::new(log2_n) else {
        return SCRYPT_OPT_UNSUPPORTED_PARAM_SPACE;
    };
    scrypt(password, salt, log2_n, r, p, output);
    0
}

#[unsafe(export_name = "scrypt_kdf")]
/// C export for scrypt_kdf using a libscrypt-kdf compatible API.
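///
/// A sketch analogous to the one on [`scrypt_c_cf`], but taking N itself; N must be a power of
/// two greater than 1, otherwise `SCRYPT_OPT_UNSUPPORTED_PARAM_SPACE` is returned (fenced as
/// `ignore` because the `scrypt_opt::compat` path is an assumption; C callers declare and call
/// the exported `scrypt_kdf` symbol):
///
/// ```ignore
/// let (password, salt) = (b"password", b"salt");
/// let mut output = [0u8; 64];
/// let rc = unsafe {
///     scrypt_opt::compat::scrypt_c(
///         password.as_ptr(), password.len(),
///         salt.as_ptr(), salt.len(),
///         16384, // N
///         8,     // r
///         1,     // p
///         output.as_mut_ptr(), output.len(),
///     )
/// };
/// assert_eq!(rc, 0);
/// ```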
pub unsafe extern "C" fn scrypt_c(
    password: *const u8,
    password_len: usize,
    salt: *const u8,
    salt_len: usize,
    n: u64,
    r: u32,
    p: u32,
    output: *mut u8,
    output_len: usize,
) -> core::ffi::c_int {
    // N must be a power of two greater than 1; checking n == 0 first also avoids an
    // out-of-range shift below (n == 0 has 64 trailing zeros).
    let log2_n = n.trailing_zeros();
    if n == 0 || log2_n == 0 || 1 << log2_n != n {
        return SCRYPT_OPT_UNSUPPORTED_PARAM_SPACE;
    }
    let Some(log2_n) = NonZeroU8::new(log2_n as u8) else {
        return SCRYPT_OPT_UNSUPPORTED_PARAM_SPACE;
    };
    let password = unsafe { core::slice::from_raw_parts(password, password_len) };
    let salt = unsafe { core::slice::from_raw_parts(salt, salt_len) };
    let output = unsafe { core::slice::from_raw_parts_mut(output, output_len) };
    let Some(r) = NonZeroU32::new(r) else {
        return SCRYPT_OPT_UNSUPPORTED_PARAM_SPACE;
    };
    let Some(p) = NonZeroU32::new(p) else {
        return SCRYPT_OPT_UNSUPPORTED_PARAM_SPACE;
    };
    scrypt(password, salt, log2_n, r, p, output);
    0
}

/// Compute the minimum buffer length in bytes to allocate for [`scrypt_ro_mix`],
/// i.e. `128 * r * ((1 << cf) + 2)`.
///
/// Returns 0 if the parameters are invalid or unsupported.
///
/// ```c
/// #include <stdint.h>
/// #include <stdio.h>
/// #include <stdlib.h>
///
/// extern size_t scrypt_ro_mix_minimum_buffer_len(unsigned int r, unsigned int cf);
/// extern int scrypt_ro_mix(uint8_t *front_buffer, uint8_t *back_buffer,
///                          const uint8_t **salt_output, uint32_t r, uint8_t cf,
///                          size_t minimum_buffer_size);
///
/// int main(void) {
///     size_t minimum_buffer_len = scrypt_ro_mix_minimum_buffer_len(1, 1);
///     printf("Minimum buffer length: %zu\n", minimum_buffer_len);
///     if (!minimum_buffer_len) {
///         return 1;
///     }
///     void *alloc = aligned_alloc(64, minimum_buffer_len);
///     if (alloc == NULL) {
///         return 2;
///     }
///     scrypt_ro_mix(alloc, alloc, NULL, 1, 1, minimum_buffer_len);
///     free(alloc);
///     return 0;
/// }
/// ```
#[unsafe(export_name = "scrypt_ro_mix_minimum_buffer_len")]
unsafe extern "C" fn scrypt_ro_mix_minimum_buffer_len(
    r: core::ffi::c_uint,
    cf: core::ffi::c_uint,
) -> usize {
    let Ok(cf) = cf.try_into() else {
        return 0;
    };
    let Some(cf) = NonZeroU8::new(cf) else {
        return 0;
    };

    // 128 * r * (N + 2) bytes, computed with checked arithmetic so oversized parameters
    // report failure (0) instead of overflowing.
    1usize
        .checked_shl(cf.get() as u32)
        .and_then(|n| n.checked_add(2))
        .and_then(|n| n.checked_mul(r as usize))
        .and_then(|n| n.checked_mul(128))
        .unwrap_or(0)
}

/// C export for scrypt_ro_mix.
///
/// Parameters:
/// - `front_buffer`: In. Pointer to the buffer to perform the RoMix_front operation on. Can be null.
/// - `back_buffer`: In. Pointer to the buffer to perform the RoMix_back operation on. Can be null. May be the same pointer as `front_buffer` (the front and back phases then run sequentially on that one buffer), but must not partially overlap it.
/// - `salt_output`: Out. Pointer that receives a pointer to the raw salt corresponding to the back buffer. Can be null; must be null if `back_buffer` is null.
/// - `r`: In. The r value.
/// - `cf`: In. The cost factor, i.e. log2 of N.
/// - `minimum_buffer_size`: In. The size in bytes of the smaller of the two buffers. See the call sketch below.
///
/// Returns:
/// - 0 on success.
/// - `SCRYPT_OPT_INVALID_BUFFER_SIZE` if the buffers are too small or the combination of null pointers is invalid.
/// - `SCRYPT_OPT_INVALID_BUFFER_ALIGNMENT` if a buffer is not 64-byte aligned.
/// - `SCRYPT_OPT_UNSUPPORTED_PARAM_SPACE` if the parameters are unsupported.
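///
/// A sketch of the two-buffer pipelined call pattern from Rust, mirroring the unit test below
/// (fenced as `ignore` because the symbol is exported for C callers rather than as public Rust
/// API; `buf0`, `buf1`, `r`, `cf` and `len` are assumed to be set up by the caller):
///
/// ```ignore
/// // Both buffers are 64-byte aligned, at least scrypt_ro_mix_minimum_buffer_len(r, cf)
/// // bytes long, and have their RoMix input blocks already filled in.
/// // Front phase of buffer 0 only.
/// scrypt_ro_mix(buf0, core::ptr::null_mut(), core::ptr::null_mut(), r, cf, len);
///
/// // Back phase of buffer 0 interleaved with the front phase of buffer 1;
/// // `salt0` then points at buffer 0's raw salt.
/// let mut salt0 = core::ptr::null();
/// scrypt_ro_mix(buf1, buf0, &mut salt0, r, cf, len);
///
/// // Finish buffer 1 on its own.
/// let mut salt1 = core::ptr::null();
/// scrypt_ro_mix(core::ptr::null_mut(), buf1, &mut salt1, r, cf, len);
/// ```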
#[unsafe(export_name = "scrypt_ro_mix")]
unsafe extern "C" fn scrypt_ro_mix(
    front_buffer: *mut u8,
    back_buffer: *mut u8,
    salt_output: *mut *const u8,
    r: u32,
    cf: u8,
    minimum_buffer_size: usize,
) -> core::ffi::c_int {
    // The NonZero constructors double as the r == 0 and cf == 0 parameter checks.
    let Some(r) = NonZeroU32::new(r) else {
        return SCRYPT_OPT_UNSUPPORTED_PARAM_SPACE;
    };

    let Some(cf) = NonZeroU8::new(cf) else {
        return SCRYPT_OPT_UNSUPPORTED_PARAM_SPACE;
    };

    // if both buffers are null, we can't do anything
    if front_buffer.is_null() && back_buffer.is_null() {
        return SCRYPT_OPT_INVALID_BUFFER_SIZE;
    }

    // if the front buffer is not null, it must be 64-byte aligned
    if !front_buffer.is_null() && front_buffer.align_offset(64) != 0 {
        return SCRYPT_OPT_INVALID_BUFFER_ALIGNMENT;
    }
    // if the back buffer is not null, it must be 64-byte aligned
    if !back_buffer.is_null() && back_buffer.align_offset(64) != 0 {
        return SCRYPT_OPT_INVALID_BUFFER_ALIGNMENT;
    }

    // if the back buffer is null, the salt output must be null
    if back_buffer.is_null() && !salt_output.is_null() {
        return SCRYPT_OPT_INVALID_BUFFER_SIZE;
    }

    let available_blocks =
        minimum_buffer_size / core::mem::size_of::<Align64<fixed_r::Block<U1>>>();

    // r * (N + 2) blocks, computed with checked arithmetic so oversized parameters are
    // rejected instead of overflowing.
    let Some(minimum_blocks) = 1usize
        .checked_shl(cf.get() as u32)
        .and_then(|n| n.checked_add(2))
        .and_then(|n| n.checked_mul(r.get() as usize))
    else {
        return SCRYPT_OPT_UNSUPPORTED_PARAM_SPACE;
    };
    if available_blocks < minimum_blocks {
        return SCRYPT_OPT_INVALID_BUFFER_SIZE;
    }

    if front_buffer.is_null() {
        // Back phase only.
        let mut buffer_back = unsafe {
            core::slice::from_raw_parts_mut(
                back_buffer.cast::<Align64<fixed_r::Block<U1>>>(),
                minimum_blocks,
            )
        };
        let salt_output_out = buffer_back.ro_mix_back(r, cf);
        if !salt_output.is_null() {
            unsafe {
                *salt_output = salt_output_out.as_ptr().cast();
            }
        }
    } else if back_buffer.is_null() {
        // Front phase only.
        let mut buffer_front = unsafe {
            core::slice::from_raw_parts_mut(
                front_buffer.cast::<Align64<fixed_r::Block<U1>>>(),
                minimum_blocks,
            )
        };
        buffer_front.ro_mix_front(r, cf);
    } else {
        let mut buffer_back = unsafe {
            core::slice::from_raw_parts_mut(
                back_buffer.cast::<Align64<fixed_r::Block<U1>>>(),
                minimum_blocks,
            )
        };

        let salt_output_out = if back_buffer == front_buffer {
            // Same buffer for both: run the two phases sequentially.
            buffer_back.ro_mix_front(r, cf);
            buffer_back.ro_mix_back(r, cf)
        } else {
            // Distinct buffers: interleave the back phase of one with the front phase of the other.
            let mut buffer_front = unsafe {
                core::slice::from_raw_parts_mut(
                    front_buffer.cast::<Align64<fixed_r::Block<U1>>>(),
                    minimum_blocks,
                )
            };

            buffer_back.ro_mix_interleaved(&mut buffer_front, r, cf)
        };

        if !salt_output.is_null() {
            unsafe {
                *salt_output = salt_output_out.as_ptr().cast();
            }
        }
    }

    0
}

#[cfg(target_arch = "wasm32")]
#[wasm_bindgen(js_name = "scrypt")]
/// WASM binding for scrypt. It is not really much faster even with SIMD enabled, due to the
/// complete lack of wide SIMD support on WASM; this is just a wrapper for API compatibility.
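///
/// A usage sketch from Rust (JavaScript callers see this as `scrypt` via wasm-bindgen); fenced
/// as `ignore` because the binding only compiles on `wasm32`. The return value is a lowercase
/// hex string of length `2 * dklen`, or an error message for invalid parameters:
///
/// ```ignore
/// let hex = scrypt_wasm(b"password", b"salt", 16384, 8, 1, 32);
/// assert_eq!(hex.len(), 64);
/// ```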
pub fn scrypt_wasm(password: &[u8], salt: &[u8], n: u32, r: u32, p: u32, dklen: usize) -> String {
    // N must be a power of two greater than 1 for the cost factor to be well defined.
    if n < 2 || !n.is_power_of_two() {
        return String::from("Invalid n");
    }
    let log2_n = NonZeroU8::new(n.trailing_zeros() as u8).unwrap();
    if log2_n.get() as u64 >= r as u64 * 16 {
        return String::from("Invalid r");
    }
    if p as u64 > ((u32::MAX as u64 - 1) * 32) / (128 * (r as u64)) {
        return String::from("Invalid p");
    }
    if dklen == 0 {
        return String::from("dklen must be non-zero");
    }

    let mut result: Vec<u8> = vec![0; dklen * 2];
    let Some(r) = NonZeroU32::new(r) else {
        return String::from("Unsupported r value");
    };
    let Some(p) = NonZeroU32::new(p) else {
        return String::from("Unsupported p value");
    };

    // Derive the raw key into the upper half of `result`, then hex-encode it in place;
    // each source byte is read before the encoder reaches and overwrites it.
    scrypt(password, salt, log2_n, r, p, &mut result[dklen..]);
    for i in 0..dklen {
        let byte = result[dklen + i];
        let high_nibble = byte >> 4;
        let low_nibble = byte & 0b1111;
        result[i * 2] = if high_nibble < 10 {
            b'0' + high_nibble
        } else {
            b'a' + high_nibble - 10
        };
        result[i * 2 + 1] = if low_nibble < 10 {
            b'0' + low_nibble
        } else {
            b'a' + low_nibble - 10
        };
    }
    // Every byte is now an ASCII hex digit, so the unchecked UTF-8 conversion is sound.
    unsafe { String::from_utf8_unchecked(result) }
}

#[cfg(test)]
mod tests {
    use generic_array::{
        ArrayLength,
        typenum::{NonZero, U1, U2, U8, U16, U32},
    };

    use crate::pbkdf2_1::Pbkdf2HmacSha256State;

    use super::*;

    #[test]
    fn test_scrypt_api() {
        for r in [1, 2, 4, 8, 16, 32] {
            for p in 1..=4 {
                for cf in 1..=7 {
                    let mut output = [0; 64];
                    let mut expected = [0; 64];
                    scrypt(
                        b"password",
                        b"salt",
                        cf.try_into().unwrap(),
                        r.try_into().unwrap(),
                        p.try_into().unwrap(),
                        &mut output,
                    );

                    let params = ::scrypt::Params::new(cf, r, p, 64).unwrap();
                    ::scrypt::scrypt(b"password", b"salt", &params, &mut expected).unwrap();

                    assert_eq!(
                        output, expected,
                        "unexpected output at r={r}, p={p}, cf={cf}"
                    );
                }
            }
        }
    }

    fn test_scrypt_ro_mix_api<R: ArrayLength + NonZero>() {
        const CF: u8 = 10;
        unsafe {
            let mut reference_buffer0 =
                crate::fixed_r::BufferSet::<_, R>::new_boxed(CF.try_into().unwrap());
            let mut reference_buffer1 =
                crate::fixed_r::BufferSet::<_, R>::new_boxed(CF.try_into().unwrap());
            reference_buffer0.set_input(&Pbkdf2HmacSha256State::new(b"password0"), b"salt");
            reference_buffer1.set_input(&Pbkdf2HmacSha256State::new(b"password1"), b"salt");

            let min_buffer_len = scrypt_ro_mix_minimum_buffer_len(R::U32, CF as u32);
            let layout = alloc::alloc::Layout::from_size_align(min_buffer_len, 64).unwrap();

            let alloc0 = alloc::alloc::alloc(layout);
            assert!(!alloc0.is_null());
            let alloc0 = core::slice::from_raw_parts_mut(alloc0, min_buffer_len);
            let alloc1 = alloc::alloc::alloc(layout);
            assert!(!alloc1.is_null());
            let alloc1 = core::slice::from_raw_parts_mut(alloc1, min_buffer_len);

            let input_slice = reference_buffer1.input_buffer();
            alloc1[..input_slice.len()].copy_from_slice(input_slice);
            let input_slice = reference_buffer0.input_buffer();
            alloc0[..input_slice.len()].copy_from_slice(input_slice);

            scrypt_ro_mix(
                alloc0.as_mut_ptr().cast(),
                core::ptr::null_mut(),
                core::ptr::null_mut(),
                R::U32,
                CF,
                min_buffer_len,
            );

            let mut alloc0_salt_output = core::ptr::null();

            assert_eq!(
                scrypt_ro_mix(
                    alloc1.as_mut_ptr().cast(),
                    alloc0.as_mut_ptr().cast(),
                    &mut alloc0_salt_output,
                    R::U32,
                    CF,
                    min_buffer_len,
                ),
                0
            );

            let mut alloc1_salt_output = core::ptr::null();

            assert_eq!(
                scrypt_ro_mix(
                    core::ptr::null_mut(),
                    alloc1.as_mut_ptr().cast(),
                    &mut alloc1_salt_output,
                    R::U32,
                    CF,
                    min_buffer_len,
                ),
                0
            );

            reference_buffer0.scrypt_ro_mix();
            reference_buffer1.scrypt_ro_mix();
            assert_eq!(
                core::slice::from_raw_parts(
                    alloc0_salt_output,
                    reference_buffer0.raw_salt_output().len()
                ),
                reference_buffer0.raw_salt_output().as_slice()
            );
            assert_eq!(
                core::slice::from_raw_parts(
                    alloc1_salt_output,
                    reference_buffer1.raw_salt_output().len()
                ),
                reference_buffer1.raw_salt_output().as_slice()
            );

            alloc::alloc::dealloc(alloc0.as_mut_ptr().cast(), layout);
            alloc::alloc::dealloc(alloc1.as_mut_ptr().cast(), layout);
        }
    }

    #[test]
    fn test_scrypt_ro_mix_api_1() {
        test_scrypt_ro_mix_api::<U1>();
    }

    #[test]
    fn test_scrypt_ro_mix_api_2() {
        test_scrypt_ro_mix_api::<U2>();
    }

    #[test]
    fn test_scrypt_ro_mix_api_8() {
        test_scrypt_ro_mix_api::<U8>();
    }

    #[test]
    fn test_scrypt_ro_mix_api_16() {
        test_scrypt_ro_mix_api::<U16>();
    }

    #[test]
    fn test_scrypt_ro_mix_api_32() {
        test_scrypt_ro_mix_api::<U32>();
    }
521}