Skip to main content

sponge_cursor/
lib.rs

1#![no_std]
2#![doc = include_str!("../README.md")]
3#![doc(
4    html_logo_url = "https://raw.githubusercontent.com/RustCrypto/media/6ee8e381/logo.svg",
5    html_favicon_url = "https://raw.githubusercontent.com/RustCrypto/media/6ee8e381/logo.svg"
6)]
7
8mod u64_le_utils;
9
/// Cursor for implementing sponge-based absorption and squeezing.
///
/// Internally this is a single `u8` position which is kept strictly below `RATE` at all times.
///
/// `RATE` MUST be non-zero and smaller than `256`. Violations are meant to be caught at
/// compile time (via `const` assertions) rather than at run time.
#[derive(Hash, PartialEq, Eq, Clone, Debug)]
pub struct SpongeCursor<const RATE: usize> {
    /// Current position inside the `RATE`-byte block; invariant: `pos < RATE`.
    pos: u8,
}
20
21impl<const RATE: usize> Default for SpongeCursor<RATE> {
22    fn default() -> Self {
23        const {
24            assert!(RATE != 0);
25            assert!(RATE < u8::MAX as usize);
26        }
27
28        Self { pos: 0 }
29    }
30}
31
32// Note that the methods should compile into a panic-free code,
33// see: https://rust.godbolt.org/z/r93WE8zq3
34impl<const RATE: usize> SpongeCursor<RATE> {
35    /// Create new cursor with the provided position.
36    ///
37    /// Returns `None` if `pos` is bigger or equal to `RATE`.
38    #[must_use]
39    pub fn new(pos: u8) -> Option<Self> {
40        if usize::from(pos) < RATE {
41            Some(Self { pos })
42        } else {
43            None
44        }
45    }
46
47    /// Get current cursor position as `u8`.
48    #[must_use]
49    #[inline(always)]
50    #[allow(clippy::missing_panics_doc, reason = "the method is panic-free")]
51    pub fn raw_pos(&self) -> u8 {
52        let rate_u8 = u8::try_from(RATE).expect("RATE is smaller than 256");
53        debug_assert!(self.pos < rate_u8);
54        if self.pos < rate_u8 {
55            self.pos
56        } else {
57            // SAFETY: the type enforces that `pos` is always smaller than `RATE`
58            unsafe { core::hint::unreachable_unchecked() };
59        }
60    }
61
62    /// Get current cursor position as `usize`.
63    #[must_use]
64    #[inline(always)]
65    pub fn pos(&self) -> usize {
66        let pos = usize::from(self.pos);
67        debug_assert!(pos < RATE);
68        if pos < RATE {
69            pos
70        } else {
71            // SAFETY: the type enforces that `pos` is always smaller than `RATE`
72            unsafe { core::hint::unreachable_unchecked() };
73        }
74    }
75
76    /// Set new cursor position.
77    ///
78    /// # Panics
79    /// If `new_pos` is greater or equal to `RATE`.
80    fn set_pos(&mut self, new_pos: usize) {
81        assert!(new_pos < RATE);
82        self.pos = u8::try_from(new_pos).expect("`new_pos` is smaller than `RATE`");
83    }
84
85    /// Absorb bytes from `data` into `state` using little-endian byte order.
86    ///
87    /// Size of `state` in bytes MUST be greater or equal to `RATE`.
88    /// Using an invalid `N` will result in a compilation error.
89    #[allow(clippy::missing_panics_doc, reason = "the method is panic-free")]
90    #[inline]
91    pub fn absorb_u64_le<const N: usize>(
92        &mut self,
93        state: &mut [u64; N],
94        sponge: fn(&mut [u64; N]),
95        mut data: &[u8],
96    ) {
97        const {
98            assert!(RATE <= size_of::<[u64; N]>());
99            assert!(RATE < u8::MAX as usize);
100            assert!(RATE % size_of::<u64>() == 0);
101        };
102
103        if self.pos != 0 {
104            let pos = self.pos();
105            let rem_len = RATE
106                .checked_sub(pos)
107                .expect("`pos` is always smaller than `RATE`");
108
109            if data.len() < rem_len {
110                u64_le_utils::absorb_partial::<N, RATE>(state, pos, data);
111                self.set_pos(pos + data.len());
112                return;
113            }
114
115            let (head, tail) = data.split_at(rem_len);
116            data = tail;
117            u64_le_utils::absorb_partial::<N, RATE>(state, pos, head);
118
119            sponge(state);
120        }
121
122        let blocks = data.chunks_exact(RATE);
123        let tail = blocks.remainder();
124
125        for block in blocks {
126            let block: &[u8; RATE] = block.try_into().expect("`block` has correct size");
127            u64_le_utils::absorb_full(state, block);
128            sponge(state);
129        }
130
131        if !tail.is_empty() {
132            u64_le_utils::absorb_partial::<N, RATE>(state, 0, tail);
133        }
134
135        self.set_pos(tail.len());
136    }
137
138    /// Squeeze data from `state` by reading it into `buf` using little-endian byte order.
139    ///
140    /// Size of `state` in bytes MUST be greater or equal to `RATE`.
141    /// Using an invalid `N` will result in a compilation error.
142    #[inline]
143    pub fn squeeze_read_u64_le<const N: usize>(
144        &mut self,
145        state: &mut [u64; N],
146        sponge: fn(&mut [u64; N]),
147        buf: &mut [u8],
148    ) {
149        self.squeeze_inner_u64_le(
150            state,
151            sponge,
152            buf,
153            u64_le_utils::squeeze_read_partial::<N, RATE>,
154            u64_le_utils::squeeze_read_full,
155        );
156    }
157
158    /// Squeeze data from `state` by XOR-ing it with data in `buf` using little-endian byte order.
159    ///
160    /// Size of `state` in bytes MUST be greater or equal to `RATE`.
161    /// Using an invalid `N` will result in a compilation error.
162    #[inline]
163    pub fn squeeze_xor_u64_le<const N: usize>(
164        &mut self,
165        state: &mut [u64; N],
166        sponge: fn(&mut [u64; N]),
167        buf: &mut [u8],
168    ) {
169        self.squeeze_inner_u64_le(
170            state,
171            sponge,
172            buf,
173            u64_le_utils::squeeze_xor_partial::<N, RATE>,
174            u64_le_utils::squeeze_xor_full,
175        );
176    }
177
178    /// Squeeze data by calling custom functions using little-endian byte order.
179    #[inline(always)]
180    fn squeeze_inner_u64_le<const N: usize>(
181        &mut self,
182        state: &mut [u64; N],
183        sponge: fn(&mut [u64; N]),
184        mut buf: &mut [u8],
185        process_partial: fn(&[u64; N], usize, &mut [u8]),
186        process_full: fn(&[u64; N], &mut [u8; RATE]),
187    ) {
188        const {
189            assert!(RATE <= size_of::<[u64; N]>());
190            assert!(RATE < u8::MAX as usize);
191            assert!(RATE % size_of::<u64>() == 0);
192        };
193
194        if self.pos != 0 {
195            let pos = self.pos();
196            let rem_len = RATE - pos;
197
198            if buf.len() < rem_len {
199                process_partial(state, pos, buf);
200                self.set_pos(pos + buf.len());
201                return;
202            }
203
204            let (head, tail) = buf.split_at_mut(rem_len);
205            buf = tail;
206
207            process_partial(state, pos, head);
208        }
209
210        let mut blocks = buf.chunks_exact_mut(RATE);
211
212        for block in &mut blocks {
213            sponge(state);
214            let block = block.try_into().expect("`block` has correct size");
215            process_full(state, block);
216        }
217
218        let tail = blocks.into_remainder();
219
220        if !tail.is_empty() {
221            sponge(state);
222            process_partial(state, 0, tail);
223        }
224
225        self.set_pos(tail.len());
226    }
227}
228
#[cfg(feature = "zeroize")]
impl<const RATE: usize> zeroize::Zeroize for SpongeCursor<RATE> {
    /// Zeroize the cursor by zeroizing its stored position.
    ///
    /// The resulting position `0` is a valid position for any non-zero `RATE`.
    fn zeroize(&mut self) {
        zeroize::Zeroize::zeroize(&mut self.pos);
    }
}
234}