twox_hash/xxhash3/
streaming.rs

1use core::hint::assert_unchecked;
2
3use super::{large::INITIAL_ACCUMULATORS, *};
4
/// A buffer containing the secret bytes.
///
/// # Safety
///
/// Every call to [`AsRef::as_ref`][] must return a slice with the
/// same number of elements. [`SecretBuffer`][] converts the bytes to
/// a `Secret` without re-checking the length, so an implementation
/// whose length changes would break that unchecked conversion.
pub unsafe trait FixedBuffer: AsRef<[u8]> {}
11
/// A mutable buffer to contain the secret bytes.
///
/// This is what allows a secret to be written into caller-provided
/// storage, as [`SecretBuffer::with_seed`][] does.
///
/// # Safety
///
/// Must always return a slice with the same number of elements. The
/// slice returned from [`AsMut::as_mut`][] must always be the same as
/// that returned from [`AsRef::as_ref`][].
pub unsafe trait FixedMutBuffer: FixedBuffer + AsMut<[u8]> {}
20
// Safety: The length `N` is part of the array type, so an array will
// never change size.
unsafe impl<const N: usize> FixedBuffer for [u8; N] {}

// Safety: The length `N` is part of the array type, so an array will
// never change size.
unsafe impl<const N: usize> FixedMutBuffer for [u8; N] {}

// Safety: A shared reference to an array always refers to exactly
// `N` elements.
unsafe impl<const N: usize> FixedBuffer for &[u8; N] {}

// Safety: An exclusive reference to an array always refers to
// exactly `N` elements.
unsafe impl<const N: usize> FixedBuffer for &mut [u8; N] {}

// Safety: An exclusive reference to an array always refers to
// exactly `N` elements.
unsafe impl<const N: usize> FixedMutBuffer for &mut [u8; N] {}
35
/// Number of input bytes consumed by one stripe.
const STRIPE_BYTES: usize = 64;
/// Number of whole stripes that fit in the streaming buffer.
const BUFFERED_STRIPES: usize = 4;
/// Size in bytes of the streaming buffer.
const BUFFERED_BYTES: usize = STRIPE_BYTES * BUFFERED_STRIPES;
/// Fixed-size storage for input bytes that have not been processed yet.
type Buffer = [u8; BUFFERED_BYTES];

// Ensure that a full buffer always implies we are in the 241+ byte case.
// This is checked at compile time.
const _: () = assert!(BUFFERED_BYTES > CUTOFF);
43
/// Holds secret and temporary buffers that are ensured to be
/// appropriately sized.
#[derive(Clone)]
pub struct SecretBuffer<S> {
    /// Seed stored alongside the secret.
    seed: u64,
    /// Caller-chosen storage holding the secret bytes.
    secret: S,
    /// Temporary storage for input bytes that have not been
    /// processed yet.
    buffer: Buffer,
}
52
53impl<S> SecretBuffer<S> {
54    /// Returns the secret.
55    pub fn into_secret(self) -> S {
56        self.secret
57    }
58}
59
60impl<S> SecretBuffer<S>
61where
62    S: FixedBuffer,
63{
64    /// Takes the seed, secret, and buffer and performs no
65    /// modifications to them, only validating that the sizes are
66    /// appropriate.
67    pub fn new(seed: u64, secret: S) -> Result<Self, SecretTooShortError<S>> {
68        match Secret::new(secret.as_ref()) {
69            Ok(_) => Ok(Self {
70                seed,
71                secret,
72                buffer: [0; BUFFERED_BYTES],
73            }),
74            Err(e) => Err(SecretTooShortError(e, secret)),
75        }
76    }
77
78    #[inline(always)]
79    #[cfg(test)]
80    fn is_valid(&self) -> bool {
81        let secret = self.secret.as_ref();
82
83        secret.len() >= SECRET_MINIMUM_LENGTH
84    }
85
86    #[inline]
87    fn n_stripes(&self) -> usize {
88        Self::secret(&self.secret).n_stripes()
89    }
90
91    #[inline]
92    fn parts(&self) -> (u64, &Secret, &Buffer) {
93        (self.seed, Self::secret(&self.secret), &self.buffer)
94    }
95
96    #[inline]
97    fn parts_mut(&mut self) -> (u64, &Secret, &mut Buffer) {
98        (self.seed, Self::secret(&self.secret), &mut self.buffer)
99    }
100
101    fn secret(secret: &S) -> &Secret {
102        let secret = secret.as_ref();
103        // Safety: We established the length at construction and the
104        // length is not allowed to change.
105        unsafe { Secret::new_unchecked(secret) }
106    }
107}
108
109impl<S> SecretBuffer<S>
110where
111    S: FixedMutBuffer,
112{
113    /// Fills the secret buffer with a secret derived from the seed
114    /// and the default secret. The secret must be exactly
115    /// [`DEFAULT_SECRET_LENGTH`][] bytes long.
116    pub fn with_seed(seed: u64, mut secret: S) -> Result<Self, SecretWithSeedError<S>> {
117        match <&mut DefaultSecret>::try_from(secret.as_mut()) {
118            Ok(secret_slice) => {
119                *secret_slice = DEFAULT_SECRET_RAW;
120                derive_secret(seed, secret_slice);
121
122                Ok(Self {
123                    seed,
124                    secret,
125                    buffer: [0; BUFFERED_BYTES],
126                })
127            }
128            Err(_) => Err(SecretWithSeedError(secret)),
129        }
130    }
131}
132
133impl SecretBuffer<&'static [u8; DEFAULT_SECRET_LENGTH]> {
134    /// Use the default seed and secret values while allocating nothing.
135    #[inline]
136    pub const fn default() -> Self {
137        SecretBuffer {
138            seed: DEFAULT_SEED,
139            secret: &DEFAULT_SECRET_RAW,
140            buffer: [0; BUFFERED_BYTES],
141        }
142    }
143}
144
/// State of a streaming hash computation: the secret, the pending
/// input bytes, and the accumulators for the stripes processed so
/// far.
#[derive(Clone)]
pub struct RawHasherCore<S> {
    /// Seed, secret, and the temporary buffer for pending input.
    secret_buffer: SecretBuffer<S>,
    /// Number of bytes at the start of the temporary buffer that
    /// hold pending input.
    buffer_usage: usize,
    /// Accumulator state for the stripes processed so far.
    stripe_accumulator: StripeAccumulator,
    /// Total number of input bytes written so far.
    total_bytes: usize,
}
152
153impl<S> RawHasherCore<S> {
154    pub fn new(secret_buffer: SecretBuffer<S>) -> Self {
155        Self {
156            secret_buffer,
157            buffer_usage: 0,
158            stripe_accumulator: StripeAccumulator::new(),
159            total_bytes: 0,
160        }
161    }
162
163    pub fn into_secret(self) -> S {
164        self.secret_buffer.into_secret()
165    }
166}
167
impl<S> RawHasherCore<S>
where
    S: FixedBuffer,
{
    /// Adds `input` to the data being hashed.
    ///
    /// The work is done by `write_impl`, reached through the
    /// `dispatch!` macro.
    #[inline]
    pub fn write(&mut self, input: &[u8]) {
        // NOTE(review): `dispatch!` is defined outside this file;
        // presumably it picks a `Vector` implementation and forwards
        // the locals named in the signature below — confirm before
        // renaming `this` or `input`.
        let this = self;
        dispatch! {
            fn write_impl<S>(this: &mut RawHasherCore<S>, input: &[u8])
            [S: FixedBuffer]
        }
    }

    /// Computes the hash of everything written so far, using
    /// `finalize` to produce the output.
    ///
    /// Takes `&self`, so more data may be written afterwards. The
    /// work is done by `finish_impl`, reached through the
    /// `dispatch!` macro.
    #[inline]
    pub fn finish<F>(&self, finalize: F) -> F::Output
    where
        F: Finalize,
    {
        // See the note in `write` about the names `dispatch!` uses.
        let this = self;
        dispatch! {
            fn finish_impl<S, F>(this: &RawHasherCore<S>, finalize: F) -> F::Output
            [S: FixedBuffer, F: Finalize]
        }
    }
}
193
/// Feeds `input` into the streaming state held by `this`.
///
/// The steps are, in order:
///
/// 1. Append as much of `input` as fits to the temporary buffer.
///    If the buffer is now full *and* more input remains, process
///    the buffered stripes and reset the buffer.
/// 2. Process full stripes of the remaining input in place, always
///    holding back at least `STRIPE_BYTES` bytes.
/// 3. Copy whatever is left to the start of the temporary buffer.
///
/// Data is held back because `finish_impl` needs the final bytes of
/// the input to still be available.
#[inline(always)]
fn write_impl<S>(vector: impl Vector, this: &mut RawHasherCore<S>, mut input: &[u8])
where
    S: FixedBuffer,
{
    // Nothing to record; this also guarantees `input` is non-empty
    // for the rest of the function.
    if input.is_empty() {
        return;
    }

    let RawHasherCore {
        secret_buffer,
        buffer_usage,
        stripe_accumulator,
        total_bytes,
        ..
    } = this;

    let n_stripes = secret_buffer.n_stripes();
    let (_, secret, buffer) = secret_buffer.parts_mut();

    *total_bytes += input.len();

    // Safety: This is an invariant of the buffer.
    // Every assignment to `buffer_usage` in this function keeps it
    // at or below `buffer.len()`.
    unsafe {
        debug_assert!(*buffer_usage <= buffer.len());
        assert_unchecked(*buffer_usage <= buffer.len())
    };

    // We have some previous data saved; try to fill it up and process it first
    //
    // NOTE(review): `buffer` is a fixed-size array of
    // `BUFFERED_BYTES` elements, so `!buffer.is_empty()` is always
    // true and this branch is taken on every call, even when
    // `*buffer_usage` is zero.
    if !buffer.is_empty() {
        let remaining = &mut buffer[*buffer_usage..];
        let n_to_copy = usize::min(remaining.len(), input.len());

        let (remaining_head, remaining_tail) = remaining.split_at_mut(n_to_copy);
        let (input_head, input_tail) = input.split_at(n_to_copy);

        remaining_head.copy_from_slice(input_head);
        *buffer_usage += n_to_copy;

        input = input_tail;

        // We did not fill up the buffer
        if !remaining_tail.is_empty() {
            return;
        }

        // We don't know this isn't the last of the data
        // so the full buffer is kept as-is until more input arrives.
        if input.is_empty() {
            return;
        }

        // The buffer is full and more input follows, so every
        // buffered stripe can be consumed.
        let (stripes, _) = buffer.bp_as_chunks();
        for stripe in stripes {
            stripe_accumulator.process_stripe(vector, stripe, n_stripes, secret);
        }
        *buffer_usage = 0;
    }

    debug_assert!(*buffer_usage == 0);

    // Process as much of the input data in-place as possible,
    // while leaving at least one full stripe for the
    // finalization.
    if let Some(len) = input.len().checked_sub(STRIPE_BYTES) {
        // Round down to a whole number of stripes.
        let full_block_point = (len / STRIPE_BYTES) * STRIPE_BYTES;
        // Safety: We know that `full_block_point` must be less than
        // `input.len()` as we subtracted and then integer-divided
        // (which rounds down) and then multiplied back. That's not
        // evident to the compiler and `split_at` results in a
        // potential panic.
        //
        // https://github.com/llvm/llvm-project/issues/104827
        let (stripes, remainder) = unsafe { input.split_at_unchecked(full_block_point) };
        let (stripes, _) = stripes.bp_as_chunks();

        for stripe in stripes {
            stripe_accumulator.process_stripe(vector, stripe, n_stripes, secret)
        }
        input = remainder;
    }

    // Any remaining data has to be less than the buffer, and the
    // buffer is empty so just fill up the buffer.
    debug_assert!(*buffer_usage == 0);
    debug_assert!(!input.is_empty());

    // Safety: We have parsed all the full blocks of input except one
    // and potentially a full block minus one byte. That amount of
    // data must be less than the buffer.
    let buffer_head = unsafe {
        debug_assert!(input.len() < 2 * STRIPE_BYTES);
        debug_assert!(2 * STRIPE_BYTES < buffer.len());
        buffer.get_unchecked_mut(..input.len())
    };

    buffer_head.copy_from_slice(input);
    *buffer_usage = input.len();
}
292
/// Computes the hash output for everything written to `this` so far.
///
/// `this` is not modified: the stripe accumulator is copied out of
/// it, so the stripes ingested here do not affect later writes.
///
/// When more than `CUTOFF` bytes were written in total, the buffered
/// bytes are ingested and `finalize.large` produces the result.
/// Otherwise the whole input is still in the temporary buffer and is
/// handed to `finalize.small`.
#[inline(always)]
fn finish_impl<S, F>(vector: impl Vector, this: &RawHasherCore<S>, finalize: F) -> F::Output
where
    S: FixedBuffer,
    F: Finalize,
{
    // `StripeAccumulator` is `Copy`, so `mut stripe_accumulator` is
    // a private working copy.
    let RawHasherCore {
        ref secret_buffer,
        buffer_usage,
        mut stripe_accumulator,
        total_bytes,
    } = *this;

    let n_stripes = secret_buffer.n_stripes();
    let (seed, secret, buffer) = secret_buffer.parts();

    // Safety: This is an invariant of the buffer.
    unsafe {
        debug_assert!(buffer_usage <= buffer.len());
        assert_unchecked(buffer_usage <= buffer.len())
    };

    if total_bytes > CUTOFF {
        // The bytes that were written but not yet processed.
        let input = &buffer[..buffer_usage];

        // Ingest final stripes
        // NOTE(review): `stripes_with_tail` is defined elsewhere;
        // presumably it yields the full stripes to process plus the
        // trailing bytes — confirm.
        let (stripes, remainder) = stripes_with_tail(input);
        for stripe in stripes {
            stripe_accumulator.process_stripe(vector, stripe, n_stripes, secret);
        }

        // Scratch space used only when fewer than 64 bytes are
        // buffered.
        let mut temp = [0; 64];

        // The last stripe is the final 64 bytes of the buffered data.
        let last_stripe = match input.last_chunk() {
            Some(chunk) => chunk,
            None => {
                // Fewer than 64 bytes are buffered. Make up the
                // difference with the bytes at the end of the
                // temporary buffer (presumably still holding the
                // most recently processed input — confirm), followed
                // by the buffered bytes.
                let n_to_reuse = 64 - input.len();
                let to_reuse = buffer.len() - n_to_reuse;

                let (temp_head, temp_tail) = temp.split_at_mut(n_to_reuse);
                temp_head.copy_from_slice(&buffer[to_reuse..]);
                temp_tail.copy_from_slice(input);

                &temp
            }
        };

        finalize.large(
            vector,
            stripe_accumulator.accumulator,
            remainder,
            last_stripe,
            secret,
            total_bytes,
        )
    } else {
        // At most `CUTOFF` bytes were written, which is less than
        // the buffer size, so the buffer holds the whole input.
        // Note that `DEFAULT_SECRET` is passed here, not `secret`.
        finalize.small(DEFAULT_SECRET, seed, &buffer[..total_bytes])
    }
}
352
/// Turns the state of a streaming computation into a hash value.
///
/// `finish_impl` calls exactly one of the two methods, depending on
/// whether more than `CUTOFF` bytes were written in total.
pub trait Finalize {
    /// The value produced by finalization.
    type Output;

    /// Finalizes an input of at most `CUTOFF` bytes.
    ///
    /// `input` is the complete data that was written. The streaming
    /// code passes `DEFAULT_SECRET` as `secret`.
    fn small(&self, secret: &Secret, seed: u64, input: &[u8]) -> Self::Output;

    /// Finalizes an input of more than `CUTOFF` bytes.
    ///
    /// - `acc`: the accumulators after the final stripes were
    ///   ingested.
    /// - `last_block`: the tail returned by `stripes_with_tail` for
    ///   the buffered bytes.
    /// - `last_stripe`: the last 64 bytes assembled by
    ///   `finish_impl`.
    /// - `secret`: the secret of the hasher.
    /// - `len`: the total number of bytes written.
    fn large(
        &self,
        vector: impl Vector,
        acc: [u64; 8],
        last_block: &[u8],
        last_stripe: &[u8; 64],
        secret: &Secret,
        len: usize,
    ) -> Self::Output;
}
368
#[cfg(feature = "alloc")]
#[cfg_attr(docsrs, doc(cfg(feature = "alloc")))]
pub mod with_alloc {
    use ::alloc::boxed::Box;

    use super::*;

    // Safety: A boxed slice cannot be resized in place, so its
    // length never changes.
    unsafe impl FixedBuffer for Box<[u8]> {}

    // Safety: A boxed slice cannot be resized in place, so its
    // length never changes.
    unsafe impl FixedMutBuffer for Box<[u8]> {}

    type AllocSecretBuffer = SecretBuffer<Box<[u8]>>;

    impl AllocSecretBuffer {
        /// Heap-allocates a copy of the default secret and pairs it
        /// with the default seed.
        pub fn allocate_default() -> Self {
            let secret: Box<[u8]> = Box::from(&DEFAULT_SECRET_RAW[..]);

            Self {
                seed: DEFAULT_SEED,
                secret,
                buffer: [0; BUFFERED_BYTES],
            }
        }

        /// Derives a secret from the provided seed and stores it in
        /// a fresh heap allocation.
        pub fn allocate_with_seed(seed: u64) -> Self {
            let mut derived = DEFAULT_SECRET_RAW;
            derive_secret(seed, &mut derived);
            let secret: Box<[u8]> = Box::from(&derived[..]);

            Self {
                seed,
                secret,
                buffer: [0; BUFFERED_BYTES],
            }
        }

        /// Uses the provided seed and secret, converting the secret
        /// into a boxed slice. Fails if the secret is too short.
        pub fn allocate_with_seed_and_secret(
            seed: u64,
            secret: impl Into<Box<[u8]>>,
        ) -> Result<Self, SecretTooShortError<Box<[u8]>>> {
            let secret: Box<[u8]> = secret.into();
            Self::new(seed, secret)
        }
    }

    pub type AllocRawHasher = RawHasherCore<Box<[u8]>>;

    impl AllocRawHasher {
        /// Creates a hasher backed by the default seed and secret.
        pub fn allocate_default() -> Self {
            let secret_buffer = AllocSecretBuffer::allocate_default();
            Self::new(secret_buffer)
        }

        /// Creates a hasher whose secret is derived from `seed`.
        pub fn allocate_with_seed(seed: u64) -> Self {
            let secret_buffer = AllocSecretBuffer::allocate_with_seed(seed);
            Self::new(secret_buffer)
        }

        /// Creates a hasher from the provided seed and secret. Fails
        /// if the secret is too short.
        pub fn allocate_with_seed_and_secret(
            seed: u64,
            secret: impl Into<Box<[u8]>>,
        ) -> Result<Self, SecretTooShortError<Box<[u8]>>> {
            match AllocSecretBuffer::allocate_with_seed_and_secret(seed, secret) {
                Ok(secret_buffer) => Ok(Self::new(secret_buffer)),
                Err(e) => Err(e),
            }
        }
    }
}
437
438#[cfg(feature = "alloc")]
439pub use with_alloc::AllocRawHasher;
440
/// Tracks which stripe we are currently on to know which part of the
/// secret we should be using.
#[derive(Copy, Clone)]
pub struct StripeAccumulator {
    /// Running accumulator values, updated by every processed
    /// stripe.
    pub accumulator: [u64; 8],
    /// Index of the next stripe within the current block; reset to
    /// zero once a full block has been processed.
    current_stripe: usize,
}
448
449impl StripeAccumulator {
450    pub fn new() -> Self {
451        Self {
452            accumulator: INITIAL_ACCUMULATORS,
453            current_stripe: 0,
454        }
455    }
456
457    #[inline]
458    pub fn process_stripe(
459        &mut self,
460        vector: impl Vector,
461        stripe: &[u8; 64],
462        n_stripes: usize,
463        secret: &Secret,
464    ) {
465        let Self {
466            accumulator,
467            current_stripe,
468            ..
469        } = self;
470
471        // For each stripe
472
473        // Safety: The number of stripes is determined by the
474        // block size, which is determined by the secret size.
475        let secret_stripe = unsafe { secret.stripe(*current_stripe) };
476        vector.accumulate(accumulator, stripe, secret_stripe);
477
478        *current_stripe += 1;
479
480        // After a full block's worth
481        if *current_stripe == n_stripes {
482            let secret_end = secret.last_stripe();
483            vector.round_scramble(accumulator, secret_end);
484
485            *current_stripe = 0;
486        }
487    }
488}
489
490/// The provided secret was not exactly [`DEFAULT_SECRET_LENGTH`][]
491/// bytes.
492pub struct SecretWithSeedError<S>(S);
493
494impl<S> SecretWithSeedError<S> {
495    /// Returns the secret.
496    pub fn into_secret(self) -> S {
497        self.0
498    }
499}
500
501impl<S> core::error::Error for SecretWithSeedError<S> {}
502
503impl<S> core::fmt::Debug for SecretWithSeedError<S> {
504    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
505        f.debug_tuple("SecretWithSeedError").finish()
506    }
507}
508
509impl<S> core::fmt::Display for SecretWithSeedError<S> {
510    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
511        write!(
512            f,
513            "The secret must be exactly {DEFAULT_SECRET_LENGTH} bytes"
514        )
515    }
516}
517
518/// The provided secret was not at least [`SECRET_MINIMUM_LENGTH`][]
519/// bytes.
520pub struct SecretTooShortError<S>(secret::Error, S);
521
522impl<S> SecretTooShortError<S> {
523    /// Returns the secret.
524    pub fn into_secret(self) -> S {
525        self.1
526    }
527}
528
529impl<S> core::error::Error for SecretTooShortError<S> {}
530
531impl<S> core::fmt::Debug for SecretTooShortError<S> {
532    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
533        f.debug_tuple("SecretTooShortError").finish()
534    }
535}
536
537impl<S> core::fmt::Display for SecretTooShortError<S> {
538    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
539        self.0.fmt(f)
540    }
541}
542
#[cfg(test)]
mod test {
    use super::*;

    /// The allocation-free default must satisfy the minimum secret
    /// length.
    #[test]
    fn secret_buffer_default_is_valid() {
        assert!(SecretBuffer::default().is_valid());
    }

    /// `allocate_default` only exists with the `alloc` feature, so
    /// this test is gated to keep the test build compiling without
    /// it.
    #[cfg(feature = "alloc")]
    #[test]
    fn secret_buffer_allocate_default_is_valid() {
        assert!(SecretBuffer::allocate_default().is_valid());
    }

    /// `allocate_with_seed` only exists with the `alloc` feature, so
    /// this test is gated to keep the test build compiling without
    /// it.
    #[cfg(feature = "alloc")]
    #[test]
    fn secret_buffer_allocate_with_seed_is_valid() {
        assert!(SecretBuffer::allocate_with_seed(0xdead_beef).is_valid());
    }
}
561}