Skip to main content

spg_crypto/
lib.rs

1//! BLAKE3 cryptographic hash — self-built single-thread implementation.
2//! Follows the spec at
3//! <https://github.com/BLAKE3-team/BLAKE3-specs/blob/master/blake3.pdf>.
4//!
5//! Scope: unkeyed `hash(input) -> [u8; 32]` only. KDF / keyed-hash modes
6//! are out of scope.
7//!
8//! v3.0.4 attempted a NEON-vectorised `compress` for aarch64 but the
9//! benchmark regressed 1.5–2× — see the comment on `fn compress`.
10//! The NEON path is kept under `#[cfg(test)]` as a cross-check oracle.
11#![no_std]
12// BLAKE3 intentionally splits a 64-bit counter into two 32-bit words and
13// writes a u32 block length that is always ≤ 64. Clippy's truncation warning
14// is correct in general but here the truncation is the protocol.
15#![allow(clippy::cast_possible_truncation)]
16// Workspace-wide `unsafe_code = "deny"` (v3.0.4 — was forbid). spg-crypto
17// is the one crate that needs unsafe for `std::arch::aarch64` /
18// `std::arch::x86_64` intrinsics; the allow is scoped to this crate
19// alone.
20#![allow(unsafe_code)]
21
22extern crate alloc;
23
24pub mod base64;
25pub mod crc32;
26pub mod hmac;
27pub mod lzss;
28pub mod pbkdf2;
29pub mod sha256;
30
31use alloc::vec::Vec;
32
/// Length in bytes of a BLAKE3 digest produced by [`hash`].
pub const OUT_LEN: usize = 32;
/// Bytes per compression block: sixteen little-endian `u32` message words.
const BLOCK_LEN: usize = 64;
/// Bytes per chunk (sixteen blocks); chunks are the leaves of the hash tree.
const CHUNK_LEN: usize = 1024;

// Domain-separation flag bits from the BLAKE3 spec. Only the bits needed by
// the unkeyed `hash` mode are defined here.
/// Set on the first block of every chunk.
const CHUNK_START: u32 = 0b0001;
/// Set on the last block of every chunk.
const CHUNK_END: u32 = 0b0010;
/// Set on every parent-node compression.
const PARENT: u32 = 0b0100;
/// Set on the compression whose output is the final digest.
const ROOT: u32 = 0b1000;

/// BLAKE3 initialisation vector (the same eight words as SHA-256's initial
/// hash value).
const IV: [u32; 8] = [
    0x6A09_E667,
    0xBB67_AE85,
    0x3C6E_F372,
    0xA54F_F53A,
    0x510E_527F,
    0x9B05_688C,
    0x1F83_D9AB,
    0x5BE0_CD19,
];

/// Message-word schedule (BLAKE3 spec §2.4). Between rounds, word `i` of the
/// block becomes word `MSG_PERMUTATION[i]` of the previous round's block.
const MSG_PERMUTATION: [usize; 16] = [2, 6, 3, 10, 7, 0, 4, 13, 1, 11, 12, 5, 9, 14, 15, 8];
56
#[cfg(all(target_arch = "aarch64", test))]
mod neon {
    //! NEON (`uint32x4_t`) BLAKE3 compression for aarch64. It is compiled only
    //! for tests, as a cross-check oracle against the scalar path. The 16-word
    //! state is held as four 128-bit row vectors. Each round runs the column
    //! and diagonal mixes with vector add / xor / rotate, and the result is
    //! stored back into a `[u32; 16]`. Output must stay bit-identical to
    //! `compress_scalar`; the `neon_matches_scalar` unit test checks this.
    //!
    //! The helpers are `unsafe fn` only because the NEON intrinsics are, and
    //! they must run on a CPU with NEON.
    use super::{IV, permute};
    use core::arch::aarch64::{
        uint32x4_t, vaddq_u32, veorq_u32, vextq_u32, vld1q_u32, vld2q_u32, vshlq_n_u32,
        vshrq_n_u32, vst1q_u32,
    };

    // Stable Rust forbids const arithmetic on generic const params
    // (`{ 32 - N }`). Each BLAKE3 rotation amount (16, 12, 8, 7) therefore
    // gets its own lane-wise right-rotate: `x >> N` combined with
    // `x << (32 - N)`. The two shifted halves share no set bits, so XOR acts
    // as OR.

    #[inline]
    unsafe fn vrotr16(x: uint32x4_t) -> uint32x4_t {
        // SAFETY: register-only NEON arithmetic; see the module docs.
        unsafe { veorq_u32(vshrq_n_u32::<16>(x), vshlq_n_u32::<16>(x)) }
    }
    #[inline]
    unsafe fn vrotr12(x: uint32x4_t) -> uint32x4_t {
        // SAFETY: register-only NEON arithmetic; see the module docs.
        unsafe { veorq_u32(vshrq_n_u32::<12>(x), vshlq_n_u32::<20>(x)) }
    }
    #[inline]
    unsafe fn vrotr8(x: uint32x4_t) -> uint32x4_t {
        // SAFETY: register-only NEON arithmetic; see the module docs.
        unsafe { veorq_u32(vshrq_n_u32::<8>(x), vshlq_n_u32::<24>(x)) }
    }
    #[inline]
    unsafe fn vrotr7(x: uint32x4_t) -> uint32x4_t {
        // SAFETY: register-only NEON arithmetic; see the module docs.
        unsafe { veorq_u32(vshrq_n_u32::<7>(x), vshlq_n_u32::<25>(x)) }
    }

    /// Vectorised g-mixer. Lane `i` of `(a, b, c, d, mx, my)` is one
    /// independent scalar `g`, so a single call performs four of them.
    #[inline]
    unsafe fn g(
        a: &mut uint32x4_t,
        b: &mut uint32x4_t,
        c: &mut uint32x4_t,
        d: &mut uint32x4_t,
        mx: uint32x4_t,
        my: uint32x4_t,
    ) {
        // SAFETY: register-only NEON arithmetic; see the module docs.
        unsafe {
            *a = vaddq_u32(vaddq_u32(*a, *b), mx);
            *d = vrotr16(veorq_u32(*d, *a));
            *c = vaddq_u32(*c, *d);
            *b = vrotr12(veorq_u32(*b, *c));
            *a = vaddq_u32(vaddq_u32(*a, *b), my);
            *d = vrotr8(veorq_u32(*d, *a));
            *c = vaddq_u32(*c, *d);
            *b = vrotr7(veorq_u32(*b, *c));
        }
    }

    /// One BLAKE3 round (column mixes, then diagonal mixes) over the four-row
    /// state, reading message words from `m`.
    #[inline]
    unsafe fn one_round(
        v0: &mut uint32x4_t,
        v1: &mut uint32x4_t,
        v2: &mut uint32x4_t,
        v3: &mut uint32x4_t,
        m: &[u32; 16],
    ) {
        // SAFETY: `m` holds 16 words, so the two 8-word `vld2q_u32` loads at
        // offsets 0 and 8 stay in bounds. Everything else is register-only.
        unsafe {
            // Column mixes: lane i takes (m[2i], m[2i + 1]). `vld2q_u32`
            // de-interleaves 8 contiguous words into (.0 = evens, .1 = odds).
            let column = vld2q_u32(m.as_ptr());
            g(v0, v1, v2, v3, column.0, column.1);

            // Diagonal mixes: shift rows 1 / 2 / 3 by 1 / 2 / 3 lanes (one EXT
            // each) so that every lane holds one diagonal, then mix. Afterwards,
            // shift the rows back with the complementary EXT.
            let mut d1 = vextq_u32::<1>(*v1, *v1);
            let mut d2 = vextq_u32::<2>(*v2, *v2);
            let mut d3 = vextq_u32::<3>(*v3, *v3);
            let diagonal = vld2q_u32(m[8..].as_ptr());
            g(v0, &mut d1, &mut d2, &mut d3, diagonal.0, diagonal.1);
            *v1 = vextq_u32::<3>(d1, d1);
            *v2 = vextq_u32::<2>(d2, d2);
            *v3 = vextq_u32::<1>(d3, d3);
        }
    }

    /// NEON-vectorised compress. It has the same signature as the scalar
    /// `compress_scalar` and must produce bit-for-bit the same output.
    ///
    /// # Safety
    /// The CPU must support NEON.
    #[target_feature(enable = "neon")]
    pub unsafe fn compress(
        chaining_value: &[u32; 8],
        block_words: &[u32; 16],
        counter: u64,
        block_len: u32,
        flags: u32,
    ) -> [u32; 16] {
        // Row 3 of the initial state is (counter low, counter high, block_len,
        // flags). No IV word lands there, so load it straight from an array.
        let row3 = [counter as u32, (counter >> 32) as u32, block_len, flags];
        let mut block = *block_words;
        let mut out = [0u32; 16];
        // SAFETY: each `vld1q_u32` reads 4 words from an array position with
        // at least 4 words remaining. Each `vst1q_u32` writes 4 words into
        // `out` at offset 0, 4, 8 or 12.
        unsafe {
            let mut v0 = vld1q_u32(chaining_value.as_ptr());
            let mut v1 = vld1q_u32(chaining_value[4..].as_ptr());
            let mut v2 = vld1q_u32(IV.as_ptr());
            let mut v3 = vld1q_u32(row3.as_ptr());

            for round_idx in 0..7 {
                one_round(&mut v0, &mut v1, &mut v2, &mut v3, &block);
                // Six permutations separate the seven rounds; none follows
                // the last one.
                if round_idx < 6 {
                    permute(&mut block);
                }
            }

            // Output mixing, using the pre-mix rows 2 and 3:
            //   state[i]     ^= state[i + 8]
            //   state[i + 8] ^= chaining_value[i]
            v0 = veorq_u32(v0, v2);
            v1 = veorq_u32(v1, v3);
            v2 = veorq_u32(v2, vld1q_u32(chaining_value.as_ptr()));
            v3 = veorq_u32(v3, vld1q_u32(chaining_value[4..].as_ptr()));

            vst1q_u32(out.as_mut_ptr(), v0);
            vst1q_u32(out[4..].as_mut_ptr(), v1);
            vst1q_u32(out[8..].as_mut_ptr(), v2);
            vst1q_u32(out[12..].as_mut_ptr(), v3);
        }
        out
    }
}
198
/// The BLAKE3 mixing function `G`. It stirs the state words at indices `a`,
/// `b`, `c` and `d` together, absorbing message word `mx` in the first half
/// and `my` in the second.
#[inline]
fn g(state: &mut [u32; 16], a: usize, b: usize, c: usize, d: usize, mx: u32, my: u32) {
    // Both halves of G run the same add / xor-rotate / add / xor-rotate
    // ladder. Only the message word and the two rotation amounts differ.
    let mut ladder = |word: u32, rot_d: u32, rot_b: u32| {
        state[a] = state[a].wrapping_add(state[b]).wrapping_add(word);
        state[d] = (state[d] ^ state[a]).rotate_right(rot_d);
        state[c] = state[c].wrapping_add(state[d]);
        state[b] = (state[b] ^ state[c]).rotate_right(rot_b);
    };
    ladder(mx, 16, 12);
    ladder(my, 8, 7);
}
210
211fn round(state: &mut [u32; 16], m: &[u32; 16]) {
212    // Column.
213    g(state, 0, 4, 8, 12, m[0], m[1]);
214    g(state, 1, 5, 9, 13, m[2], m[3]);
215    g(state, 2, 6, 10, 14, m[4], m[5]);
216    g(state, 3, 7, 11, 15, m[6], m[7]);
217    // Diagonal.
218    g(state, 0, 5, 10, 15, m[8], m[9]);
219    g(state, 1, 6, 11, 12, m[10], m[11]);
220    g(state, 2, 7, 8, 13, m[12], m[13]);
221    g(state, 3, 4, 9, 14, m[14], m[15]);
222}
223
224fn permute(m: &mut [u32; 16]) {
225    let original = *m;
226    for i in 0..16 {
227        m[i] = original[MSG_PERMUTATION[i]];
228    }
229}
230
231/// Compression function (BLAKE3 spec §2.3). Returns the 16-word post-mix
232/// state; chaining uses the first 8 words.
233///
234/// v3.0.4 measured: a NEON implementation processing one block across
235/// 4 lanes regressed the bench by 1.5–2×. The reason — scalar BLAKE3
236/// is already heavily auto-vectorised by LLVM, and a within-block lane
237/// split adds 6 EXT permutes per round (42 extra instructions per
238/// compress) without buying parallelism. The real SIMD win for BLAKE3
239/// is 4-chunk-parallel compression, which doesn't apply to SPG's
240/// per-entry audit-log + per-small-catalog hash workload. The NEON
241/// path is kept (gated behind `#[cfg(test)]`) as a cross-check oracle
242/// only; runtime stays on scalar.
243fn compress(
244    chaining_value: &[u32; 8],
245    block_words: &[u32; 16],
246    counter: u64,
247    block_len: u32,
248    flags: u32,
249) -> [u32; 16] {
250    compress_scalar(chaining_value, block_words, counter, block_len, flags)
251}
252
253fn compress_scalar(
254    chaining_value: &[u32; 8],
255    block_words: &[u32; 16],
256    counter: u64,
257    block_len: u32,
258    flags: u32,
259) -> [u32; 16] {
260    let mut state = [
261        chaining_value[0],
262        chaining_value[1],
263        chaining_value[2],
264        chaining_value[3],
265        chaining_value[4],
266        chaining_value[5],
267        chaining_value[6],
268        chaining_value[7],
269        IV[0],
270        IV[1],
271        IV[2],
272        IV[3],
273        counter as u32,
274        (counter >> 32) as u32,
275        block_len,
276        flags,
277    ];
278    let mut block = *block_words;
279    round(&mut state, &block); // 1
280    permute(&mut block);
281    round(&mut state, &block); // 2
282    permute(&mut block);
283    round(&mut state, &block); // 3
284    permute(&mut block);
285    round(&mut state, &block); // 4
286    permute(&mut block);
287    round(&mut state, &block); // 5
288    permute(&mut block);
289    round(&mut state, &block); // 6
290    permute(&mut block);
291    round(&mut state, &block); // 7
292
293    // Output mixing — spec §2.3.
294    for i in 0..8 {
295        state[i] ^= state[i + 8];
296        state[i + 8] ^= chaining_value[i];
297    }
298    state
299}
300
301fn words_from_le_bytes(bytes: &[u8; BLOCK_LEN]) -> [u32; 16] {
302    let mut m = [0u32; 16];
303    for (i, chunk) in bytes.chunks_exact(4).enumerate() {
304        m[i] = u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
305    }
306    m
307}
308
309fn bytes_from_le_words(words: &[u32; 8]) -> [u8; OUT_LEN] {
310    let mut out = [0u8; OUT_LEN];
311    for (i, w) in words.iter().enumerate() {
312        out[i * 4..(i + 1) * 4].copy_from_slice(&w.to_le_bytes());
313    }
314    out
315}
316
317/// Hash one chunk (≤ 1024 bytes). Returns the chunk's chaining value.
318fn hash_chunk(input: &[u8], chunk_counter: u64, is_root: bool, base_flags: u32) -> [u32; 8] {
319    debug_assert!(input.len() <= CHUNK_LEN);
320    let block_count = if input.is_empty() {
321        1
322    } else {
323        input.len().div_ceil(BLOCK_LEN)
324    };
325
326    let mut cv = IV;
327    for b_idx in 0..block_count {
328        let start = b_idx * BLOCK_LEN;
329        let end = core::cmp::min(start + BLOCK_LEN, input.len());
330        let mut block = [0u8; BLOCK_LEN];
331        if end > start {
332            block[..end - start].copy_from_slice(&input[start..end]);
333        }
334        let block_words = words_from_le_bytes(&block);
335        let block_len = (end - start) as u32;
336        let mut flags = base_flags;
337        if b_idx == 0 {
338            flags |= CHUNK_START;
339        }
340        if b_idx == block_count - 1 {
341            flags |= CHUNK_END;
342            if is_root {
343                flags |= ROOT;
344            }
345        }
346        let state = compress(&cv, &block_words, chunk_counter, block_len, flags);
347        cv.copy_from_slice(&state[..8]);
348    }
349    cv
350}
351
352/// Parent-node compression — counter is always 0, `block_len` always 64.
353fn parent_cv(left: &[u32; 8], right: &[u32; 8], is_root: bool, base_flags: u32) -> [u32; 8] {
354    let mut block_words = [0u32; 16];
355    block_words[..8].copy_from_slice(left);
356    block_words[8..].copy_from_slice(right);
357    let mut flags = base_flags | PARENT;
358    if is_root {
359        flags |= ROOT;
360    }
361    let state = compress(&IV, &block_words, 0, BLOCK_LEN as u32, flags);
362    let mut cv = [0u32; 8];
363    cv.copy_from_slice(&state[..8]);
364    cv
365}
366
367/// Hash a subtree (must contain ≥ 1 chunk worth of bytes when called from
368/// the top level via [`hash`]). Returns the subtree's chaining value.
369///
370/// BLAKE3 trees are left-balanced: at each internal node the left subtree
371/// holds the largest power-of-two chunks that still leave the right side
372/// non-empty.
373fn hash_subtree(input: &[u8], chunk_counter_base: u64, base_flags: u32) -> [u32; 8] {
374    if input.len() <= CHUNK_LEN {
375        return hash_chunk(input, chunk_counter_base, false, base_flags);
376    }
377    let total_chunks = input.len().div_ceil(CHUNK_LEN);
378    let left_chunks = largest_power_of_two_leq(total_chunks - 1);
379    let left_len = left_chunks * CHUNK_LEN;
380    let left = &input[..left_len];
381    let right = &input[left_len..];
382    let left_cv = hash_subtree(left, chunk_counter_base, base_flags);
383    let right_cv = hash_subtree(right, chunk_counter_base + left_chunks as u64, base_flags);
384    parent_cv(&left_cv, &right_cv, false, base_flags)
385}
386
/// Largest power of two that is `<= n`.
///
/// Used to size a left subtree: a node covering `total` chunks gives its
/// left child `largest_power_of_two_leq(total - 1)` of them.
///
/// # Panics
/// Panics if `n == 0`, in every build profile, because no power of two is
/// `<= 0`. The previous `usize::BITS - 1 - leading_zeros` form only caught
/// this in debug builds. In release builds it wrapped and silently
/// returned `1 << 63`.
fn largest_power_of_two_leq(n: usize) -> usize {
    // `ilog2` is the index of the highest set bit; shifting 1 there isolates it.
    1 << n.ilog2()
}
393
394/// Top-level BLAKE3 hash. Returns the 32-byte digest.
395pub fn hash(input: &[u8]) -> [u8; OUT_LEN] {
396    let base_flags: u32 = 0;
397    if input.len() <= CHUNK_LEN {
398        let cv = hash_chunk(input, 0, true, base_flags);
399        return bytes_from_le_words(&cv);
400    }
401    // Multi-chunk: split + recurse, parent at root flags ROOT.
402    let total_chunks = input.len().div_ceil(CHUNK_LEN);
403    let left_chunks = largest_power_of_two_leq(total_chunks - 1);
404    let left_len = left_chunks * CHUNK_LEN;
405    let left = &input[..left_len];
406    let right = &input[left_len..];
407    let left_cv = hash_subtree(left, 0, base_flags);
408    let right_cv = hash_subtree(right, left_chunks as u64, base_flags);
409    let root_cv = parent_cv(&left_cv, &right_cv, true, base_flags);
410    bytes_from_le_words(&root_cv)
411}
412
413/// Helper: format a 32-byte digest as a lower-case hex string (no separators).
414/// Allocates a 64-character `String`. Useful for tests / human-facing logs.
415pub fn hex(digest: &[u8; OUT_LEN]) -> alloc::string::String {
416    const HEX: &[u8; 16] = b"0123456789abcdef";
417    let mut out = Vec::with_capacity(OUT_LEN * 2);
418    for &b in digest {
419        out.push(HEX[(b >> 4) as usize]);
420        out.push(HEX[(b & 0x0F) as usize]);
421    }
422    // We only emit ASCII chars, so the bytes are valid UTF-8.
423    alloc::string::String::from_utf8(out).expect("hex output is ASCII")
424}
425
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::string::String;
    use alloc::vec::Vec;

    /// Official BLAKE3 digests (first 32 bytes) for `""` and `"abc"`.
    const EMPTY_KAT: &str = "af1349b9f5f9a1a6a0404dea36dcc9499bcb25c9adc112b7cc9a93cae41f3262";
    const ABC_KAT: &str = "6437b3ac38465133ffb63b75273a8db548c558465d79db03fd359c6cd5bd9d85";

    fn h(s: &str) -> String {
        hex(&hash(s.as_bytes()))
    }

    /// Input pattern used by the official BLAKE3 test vectors: byte `i` is
    /// `i % 251`.
    fn test_input(len: usize) -> Vec<u8> {
        (0..len).map(|i| (i % 251) as u8).collect()
    }

    /// Tree reference that never computes a split point. It uses the
    /// incremental chaining-value-stack reduction from the BLAKE3 reference
    /// implementation. Agreement with `hash` therefore cross-checks the
    /// recursive left-balanced split, the chunk counters, and where the ROOT
    /// flag lands.
    fn stack_hash(input: &[u8]) -> [u8; OUT_LEN] {
        if input.len() <= CHUNK_LEN {
            return bytes_from_le_words(&hash_chunk(input, 0, true, 0));
        }
        let chunks: Vec<&[u8]> = input.chunks(CHUNK_LEN).collect();
        let (last, rest) = chunks.split_last().expect("multi-chunk input has chunks");
        let mut stack: Vec<[u32; 8]> = Vec::new();
        for (index, chunk) in rest.iter().enumerate() {
            let mut cv = hash_chunk(chunk, index as u64, false, 0);
            // Each trailing zero bit of the running chunk count marks one
            // completed subtree. Merge it with its left sibling on the stack.
            let mut completed = index as u64 + 1;
            while completed & 1 == 0 {
                let left = stack.pop().expect("stack holds a left sibling");
                cv = parent_cv(&left, &cv, false, 0);
                completed >>= 1;
            }
            stack.push(cv);
        }
        // The final chunk is never merged eagerly. Fold the stack right to
        // left, flagging only the last (top) parent as the root.
        let mut cv = hash_chunk(last, rest.len() as u64, false, 0);
        while let Some(left) = stack.pop() {
            cv = parent_cv(&left, &cv, stack.is_empty(), 0);
        }
        bytes_from_le_words(&cv)
    }

    #[test]
    fn empty_input_matches_blake3_kat() {
        // Official BLAKE3 KAT for empty input.
        assert_eq!(h(""), EMPTY_KAT);
    }

    #[test]
    fn abc_matches_blake3_kat() {
        assert_eq!(h("abc"), ABC_KAT);
    }

    #[test]
    fn multi_chunk_tree_matches_stack_reference() {
        // The lengths straddle chunk boundaries and produce both balanced and
        // unbalanced trees, up to 31 chunks.
        for len in [
            CHUNK_LEN - 1,
            CHUNK_LEN,
            CHUNK_LEN + 1,
            2 * CHUNK_LEN,
            2 * CHUNK_LEN + 1,
            3 * CHUNK_LEN,
            3 * CHUNK_LEN + 1,
            4 * CHUNK_LEN,
            4 * CHUNK_LEN + 1,
            5 * CHUNK_LEN + 1,
            8 * CHUNK_LEN + 1,
            16 * CHUNK_LEN,
            31 * CHUNK_LEN,
        ] {
            let input = test_input(len);
            assert_eq!(hash(&input), stack_hash(&input), "tree mismatch at len={len}");
        }
    }

    #[test]
    fn multi_chunk_digest_depends_on_last_chunk() {
        let mut input = test_input(3 * CHUNK_LEN + 1);
        let before = hash(&input);
        let last = input.len() - 1;
        input[last] ^= 1;
        assert_ne!(before, hash(&input));
    }

    #[cfg(target_arch = "aarch64")]
    #[test]
    fn neon_matches_scalar() {
        // `hash()` always runs the scalar path; the NEON compress is only a
        // test oracle. Compare the two compressions bit-for-bit.
        //
        // The message words are all distinct, so a wrong message schedule or
        // lane de-interleave in the NEON path changes the output. An
        // all-identical block would hide such bugs, because every permutation
        // of it is a no-op.
        let mut block = [0u32; 16];
        for (i, word) in block.iter_mut().enumerate() {
            *word = (i as u32).wrapping_mul(0x9E37_79B9) ^ 0xAA55_AA55;
        }
        let mut alt_cv = IV;
        alt_cv.reverse();
        for &cv in &[IV, alt_cv] {
            for counter in [0u64, 1, 0xFFFF_FFFFu64, u64::MAX] {
                for &flags in &[0u32, CHUNK_START, CHUNK_END, ROOT, PARENT] {
                    for &block_len in &[0u32, 1, 32, 64] {
                        let s = compress_scalar(&cv, &block, counter, block_len, flags);
                        // SAFETY: this test is compiled only for aarch64, and
                        // the oracle needs NEON there.
                        let n = unsafe { neon::compress(&cv, &block, counter, block_len, flags) };
                        assert_eq!(
                            s, n,
                            "scalar vs NEON mismatch at counter={counter} flags={flags} block_len={block_len}"
                        );
                    }
                }
            }
        }
    }

    #[test]
    fn deterministic() {
        let input = b"hello world";
        assert_eq!(hash(input), hash(input));
    }

    #[test]
    fn one_byte_difference_changes_hash() {
        assert_ne!(hash(b"abc"), hash(b"abd"));
    }

    #[test]
    fn largest_power_of_two_helper() {
        assert_eq!(largest_power_of_two_leq(1), 1);
        assert_eq!(largest_power_of_two_leq(2), 2);
        assert_eq!(largest_power_of_two_leq(3), 2);
        assert_eq!(largest_power_of_two_leq(4), 4);
        assert_eq!(largest_power_of_two_leq(7), 4);
        assert_eq!(largest_power_of_two_leq(8), 8);
        assert_eq!(largest_power_of_two_leq(1023), 512);
    }
}