sha3_utils/
bytepad.rs

1#![allow(
2    clippy::indexing_slicing,
3    reason = "The compiler can prove that the indices are in bounds"
4)]
5#![allow(
6    clippy::arithmetic_side_effects,
7    reason = "All arithmetic is in bounds"
8)]
9
10use core::{
11    cmp,
12    iter::{self, FusedIterator},
13    ops::{Add, AddAssign},
14};
15
16#[cfg(feature = "no-panic")]
17use no_panic::no_panic;
18
19use crate::{
20    enc::{left_encode, EncodedString, LeftEncode, LeftEncodeBytes, USIZE_BYTES},
21    util::as_chunks,
22};
23
24/// The minimum size, in bytes, allowed by [`bytepad_blocks`].
25pub const MIN_BLOCK_SIZE: usize = (1 + USIZE_BYTES) * 2;
26
27/// Same as [`bytepad`], but returns the data as blocks of length
28/// `W`.
29///
30/// In practice, this has helped avoid needless calls to `memcpy`
31/// and has helped remove panicking branches.
32pub fn bytepad_blocks<const W: usize>(
33    s: EncodedString<'_>,
34) -> ([u8; W], &[[u8; W]], Option<[u8; W]>) {
35    const { assert!(W >= MIN_BLOCK_SIZE, "`W` is too small") }
36
37    let (prefix, s) = s.to_parts();
38
39    // `first` is left_encode(w) || left_encode(s) || s[..n].
40    let (first, s) = {
41        let mut first = [0u8; W];
42        let mut i = 0;
43
44        #[inline(always)]
45        fn copy(dst: &mut [u8], src: &[u8]) -> usize {
46            let n = cmp::min(dst.len(), src.len());
47            dst[..n].copy_from_slice(&src[..n]);
48            n
49        }
50
51        // This copy cannot panic because W >= (1+USIZE_BYTES)*2
52        // and `w` is at most 1+USIZE_BYTES bytes long.
53        let w = left_encode(W);
54        copy(&mut first[i..], w.as_fixed_bytes());
55        i += w.len();
56
57        // Try and copy over the prefix. This copy cannot panic
58        // because W >= (1*USIZE_BYTES)*2 and `i` is at most
59        // 1+USIZE_BYTES.
60        copy(&mut first[i..], prefix.as_fixed_bytes());
61        i += prefix.len();
62
63        // Fill the remainder of the block with `s`.
64        let n = copy(&mut first[i..], s);
65        (first, &s[n..])
66    };
67
68    // `mid` is s[n..m].
69    let (mid, rest) = as_chunks(s);
70
71    // `last` is s[..m].
72    let last = if !rest.is_empty() {
73        let mut block = [0u8; W];
74        block[..rest.len()].copy_from_slice(rest);
75        Some(block)
76    } else {
77        None
78    };
79
80    (first, mid, last)
81}
82
83/// Prepends the integer encoding of `W` to `s`, then pads the
84/// result to a multiple of `W` for a non-zero `W`.
85///
86/// # Example
87///
88/// ```rust
89/// use sha3_utils::{bytepad, encode_string};
90///
91/// let v = bytepad::<32, _>([encode_string(b"hello, world!")]);
92/// assert_eq!(
93///     v.flat_map(|v| v.as_bytes().to_vec()).collect::<Vec<_>>(),
94///     &[
95///         1, 32,
96///         1, 104,
97///         104, 101, 108, 108, 111, 44, 32, 119, 111, 114, 108, 100, 33,
98///         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
99///     ],
100/// );
101///
102/// let v = bytepad::<32, _>([
103///     encode_string(b"hello, world!"),
104///     encode_string(b"hello, world!"),
105/// ]);
106/// assert_eq!(
107///     v.flat_map(|v| v.as_bytes().to_vec()).collect::<Vec<_>>(),
108///     &[
109///         1, 32,
110///         1, 104,
111///         104, 101, 108, 108, 111, 44, 32, 119, 111, 114, 108, 100, 33,
112///         1, 104,
113///         104, 101, 108, 108, 111, 44, 32, 119, 111, 114, 108, 100, 33,
114///     ],
115/// );
116/// ```
117#[inline]
118#[cfg_attr(feature = "no-panic", no_panic)]
119pub fn bytepad<'a, const W: usize, I>(s: I) -> BytePad<'a, W, <I as IntoIterator>::IntoIter>
120where
121    I: IntoIterator<Item = EncodedString<'a>>,
122{
123    const { assert!(W > 0) }
124
125    BytePad {
126        w: iter::once(BytePadItem::w(left_encode(W))),
127        x: FlatEncStrs {
128            iter: s.into_iter(),
129            s: None,
130        },
131        pad: Pad::new(left_encode(W).len()),
132        done: false,
133    }
134}
135
/// The result of [`bytepad`].
///
/// Iterating yields, in order: `left_encode(W)`, then the prefix
/// and data of each input string, then (if needed) one item of
/// zero bytes that brings the total length to a multiple of `W`.
#[derive(Clone, Debug)]
pub struct BytePad<'a, const W: usize, I> {
    // `left_encode(W)`; yielded first, exactly once.
    w: iter::Once<BytePadItem<'static>>,
    // The encoded input strings, X, flattened into prefix/data parts.
    x: FlatEncStrs<'a, I>,
    // Running output length modulo `W`. Starts at the length of
    // `left_encode(W)` and grows with every part taken from `x`.
    pad: Pad<W>,
    // True once the padding step has run, whether or not any
    // padding bytes were actually needed.
    done: bool,
}
148
149impl<'a, const W: usize, I> Iterator for BytePad<'a, W, I>
150where
151    I: Iterator<Item = EncodedString<'a>>,
152{
153    type Item = BytePadItem<'a>;
154
155    #[inline]
156    #[cfg_attr(feature = "no-panic", no_panic)]
157    fn next(&mut self) -> Option<Self::Item> {
158        if let Some(w) = self.w.next() {
159            return Some(w);
160        }
161        if let Some(v) = self.x.next() {
162            let item = v.into_item();
163            self.pad += item.len();
164            return Some(item);
165        }
166        if !self.done {
167            self.done = true;
168            let pad = self.pad.to_remainder();
169            if !pad.is_empty() {
170                return Some(BytePadItem::pad(pad));
171            }
172        }
173        None
174    }
175
176    #[cfg_attr(feature = "no-panic", no_panic)]
177    fn fold<B, F>(mut self, init: B, mut f: F) -> B
178    where
179        F: FnMut(B, Self::Item) -> B,
180    {
181        let mut accum = init;
182        if let Some(w) = self.w.next() {
183            //accum = f(accum, BytePadItem::w(w));
184            accum = f(accum, w);
185        }
186        for v in self.x {
187            let item = v.into_item();
188            self.pad += item.len();
189            accum = f(accum, item);
190        }
191        if !self.done {
192            self.done = true;
193            let pad = self.pad.to_remainder();
194            if !pad.is_empty() {
195                accum = f(accum, BytePadItem::pad(pad));
196            }
197        }
198        accum
199    }
200}
201
// Once `done` is set, `next` can only return `Some` again if one of
// the inner iterators does; with a fused `I` that never happens, so
// `BytePad` is fused whenever `I` is.
impl<'a, const W: usize, I> FusedIterator for BytePad<'a, W, I> where
    I: FusedIterator<Item = EncodedString<'a>>
{
}
206
/// An iterator that flattens [`EncodedString`]s into their
/// parts.
///
/// Each string taken from the inner iterator is yielded as two
/// consecutive items: its prefix, then its string data.
#[derive(Clone, Debug)]
struct FlatEncStrs<'a, I> {
    /// The underlying iterator of encoded strings.
    iter: I,
    /// The string half of the current [`EncodedString`], held
    /// here between yielding its prefix and yielding the data.
    s: Option<&'a [u8]>,
}
215
216impl<'a, I> Iterator for FlatEncStrs<'a, I>
217where
218    I: Iterator<Item = EncodedString<'a>>,
219{
220    type Item = EncStrPart<'a>;
221
222    #[inline]
223    #[cfg_attr(feature = "no-panic", no_panic)]
224    fn next(&mut self) -> Option<Self::Item> {
225        if let Some(s) = self.s.take() {
226            return Some(EncStrPart::S(s));
227        }
228        let v = self.iter.next()?;
229        let (p, s) = v.to_parts();
230        self.s = Some(s);
231        Some(EncStrPart::P(p))
232    }
233}
234
// `next` returns `None` only when no string data is pending and the
// inner iterator returns `None`, so this is fused whenever `I` is.
impl<'a, I> FusedIterator for FlatEncStrs<'a, I> where I: FusedIterator<Item = EncodedString<'a>> {}
236
/// Half of an [`EncodedString`].
///
/// Produced by [`FlatEncStrs`], which yields the prefix and then
/// the string data of each encoded string.
#[derive(Clone, Debug)]
enum EncStrPart<'a> {
    /// The prefix.
    P(LeftEncodeBytes),
    /// The string data.
    S(&'a [u8]),
}
245
246impl<'a> EncStrPart<'a> {
247    fn into_item(self) -> BytePadItem<'a> {
248        match self {
249            EncStrPart::P(p) => BytePadItem::p(p),
250            EncStrPart::S(s) => BytePadItem::s(s),
251        }
252    }
253}
254
/// An item from [`BytePad`].
///
/// Each item is one contiguous run of output bytes; read it with
/// [`BytePadItem::as_bytes`] or through its `AsRef<[u8]>` impl.
#[derive(Copy, Clone, Debug)]
pub struct BytePadItem<'a>(BytePadItemRepr<'a>);
258
259impl<'a> BytePadItem<'a> {
260    #[inline]
261    const fn len(&self) -> usize {
262        match &self.0 {
263            BytePadItemRepr::W(v) => v.len(),
264            BytePadItemRepr::P(v) => v.len(),
265            BytePadItemRepr::S(v) => v.len(),
266            BytePadItemRepr::Pad(v) => v.len(),
267        }
268    }
269
270    /// Returns the byte representation of this item.
271    #[inline]
272    pub const fn as_bytes(&self) -> &[u8] {
273        match &self.0 {
274            BytePadItemRepr::W(v) => v.as_bytes(),
275            BytePadItemRepr::P(v) => v.as_bytes(),
276            BytePadItemRepr::S(v) => v,
277            BytePadItemRepr::Pad(v) => v,
278        }
279    }
280
281    const fn w(v: LeftEncode) -> Self {
282        Self(BytePadItemRepr::W(v))
283    }
284    const fn p(v: LeftEncodeBytes) -> Self {
285        Self(BytePadItemRepr::P(v))
286    }
287    const fn s(v: &'a [u8]) -> Self {
288        Self(BytePadItemRepr::S(v))
289    }
290    const fn pad(v: &'static [u8]) -> Self {
291        Self(BytePadItemRepr::Pad(v))
292    }
293}
294
295impl AsRef<[u8]> for BytePadItem<'_> {
296    #[inline]
297    fn as_ref(&self) -> &[u8] {
298        self.as_bytes()
299    }
300}
301
/// The private representation behind [`BytePadItem`]: one variant
/// per kind of output the `BytePad` iterator produces.
#[derive(Copy, Clone, Debug)]
enum BytePadItemRepr<'a> {
    /// The encoded block size, `left_encode(W)`.
    W(LeftEncode),
    /// The prefix of an encoded string.
    P(LeftEncodeBytes),
    /// The data of an encoded string.
    S(&'a [u8]),
    /// The trailing zero padding.
    Pad(&'static [u8]),
}
309
/// A byte count reduced modulo `W`.
///
/// Used to work out how many zero bytes are needed to reach the
/// next multiple of `W`.
#[derive(Copy, Clone, Debug)]
struct Pad<const W: usize>(usize);

impl<const W: usize> Pad<W> {
    /// `W` zero bytes that the padding slices are borrowed from.
    const PAD: &[u8] = &[0u8; W];

    /// Creates a counter holding `n` reduced modulo `W`.
    const fn new(n: usize) -> Self {
        const { assert!(W > 0) }

        Self(n % W)
    }

    /// Returns the zero bytes needed to reach the next multiple of
    /// `W`, or an empty slice if the count is already a multiple.
    fn to_remainder(self) -> &'static [u8] {
        const { assert!(W > 0) }

        match self.0 {
            0 => &[],
            used => &Self::PAD[..W - (used % W)],
        }
    }
}

impl<const W: usize> Add for Pad<W> {
    type Output = Self;

    /// Adds two counters, reducing the sum modulo `W`.
    #[inline(always)]
    fn add(self, rhs: Self) -> Self::Output {
        const { assert!(W > 0) }

        let Self(lhs) = self;
        let Self(rhs) = rhs;
        Self((lhs + rhs) % W)
    }
}

impl<const W: usize> Add<usize> for Pad<W> {
    type Output = Self;

    /// Adds a raw byte count. `rhs` is reduced modulo `W` before
    /// the addition so the intermediate sum stays below `2 * W`.
    #[inline(always)]
    fn add(self, rhs: usize) -> Self::Output {
        const { assert!(W > 0) }

        let Self(lhs) = self;
        Self((lhs + rhs % W) % W)
    }
}

impl<const W: usize> AddAssign<usize> for Pad<W> {
    /// In-place form of `Pad + usize`.
    #[inline(always)]
    fn add_assign(&mut self, rhs: usize) {
        const { assert!(W > 0) }

        *self = *self + rhs;
    }
}

impl<const W: usize> PartialEq<usize> for Pad<W> {
    /// Compares the stored (already reduced) count with `other`.
    #[inline(always)]
    fn eq(&self, other: &usize) -> bool {
        let Self(count) = *self;
        count == *other
    }
}
373
#[cfg(test)]
mod tests {
    use super::*;
    use crate::enc::encode_string;

    /// Expected output for a single `"hello, world!"` with `W = 32`:
    /// the encoded block size, the string's prefix, the string, and
    /// zero padding up to 32 bytes.
    #[rustfmt::skip]
    const ONE_STRING: [u8; 32] = [
        1, 32,
        1, 104,
        104, 101, 108, 108, 111, 44, 32, 119, 111, 114, 108, 100, 33,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    ];

    /// Concatenates the bytes of every item produced by `items`.
    fn concat<'a>(items: impl Iterator<Item = BytePadItem<'a>>) -> Vec<u8> {
        let mut out = Vec::new();
        for item in items {
            out.extend_from_slice(item.as_bytes());
        }
        out
    }

    #[test]
    fn test_bytepad() {
        // One string: needs 15 bytes of padding.
        let got = concat(bytepad::<32, _>([encode_string(b"hello, world!")]));
        assert_eq!(got, ONE_STRING);

        // Two strings: exactly 32 bytes, so no padding is emitted.
        #[rustfmt::skip]
        let want = [
            1, 32,
            1, 104,
            104, 101, 108, 108, 111, 44, 32, 119, 111, 114, 108, 100, 33,
            1, 104,
            104, 101, 108, 108, 111, 44, 32, 119, 111, 114, 108, 100, 33,
        ];
        let got = concat(bytepad::<32, _>([
            encode_string(b"hello, world!"),
            encode_string(b"hello, world!"),
        ]));
        assert_eq!(got, want);
    }

    #[test]
    fn test_bytepad_blocks() {
        let (first, mid, last) = bytepad_blocks::<32>(encode_string(b"hello, world!"));

        let mut got = Vec::with_capacity(ONE_STRING.len());
        got.extend_from_slice(&first);
        got.extend(mid.iter().flatten());
        got.extend(last.iter().flatten());

        assert_eq!(got, ONE_STRING);
    }
}
431}