Skip to main content

seekable_stream_cipher/
chacha.rs

1use core::cmp;
2
/// A ChaCha12-based seekable stream cipher.
///
/// The keystream byte at any offset is a function of the key, the context
/// identifier and that offset alone, so [`StreamCipher::fill`] and
/// [`StreamCipher::apply_keystream`] can start anywhere in the stream without
/// generating the bytes that precede it.
///
/// This type does not derive `Debug`: its state holds the raw key words, which a
/// derived `Debug` implementation would print.
#[derive(Clone, Copy)]
pub struct StreamCipher {
    /// The ChaCha input state: words 0-3 hold the constants, 4-11 the key,
    /// 12-13 the 64-bit block counter (always zero here; each block sets it on a
    /// copy of the state) and 14-15 the context identifier.
    st: [u32; 16],
}

impl StreamCipher {
    /// The key length in bytes.
    pub const KEY_LENGTH: usize = 32;

    /// The ChaCha constants: "expand 32-byte k" read as little-endian words.
    const CONSTANTS: [u32; 4] = [0x61707865, 0x3320646e, 0x79622d32, 0x6b206574];

    /// Number of ChaCha rounds, applied as `ROUNDS / 2` double rounds.
    const ROUNDS: usize = 12;

    /// Size in bytes of one keystream block (sixteen 32-bit words).
    const BLOCK_LENGTH: usize = 64;

    /// Create a new state with the given key and context.
    ///
    /// The key must be 32 bytes long, and must be randomly generated, for example using
    /// `rand::thread_rng().gen::<[u8; 32]>()` or `getrandom::fill()`.
    ///
    /// The context identifier is used to improve multi-user security. It is loaded
    /// into state words 14 and 15 (the ChaCha nonce position).
    pub fn new(key: &[u8; Self::KEY_LENGTH], id: &[u8; 8]) -> Self {
        let mut st = [0u32; 16];
        st[..4].copy_from_slice(&Self::CONSTANTS);
        for (word, bytes) in st[4..12].iter_mut().zip(key.chunks_exact(4)) {
            *word = Self::load_u32_le(bytes);
        }
        // Words 12 and 13 (the block counter) stay zero; `keystream_block`
        // writes the block index into a copy of the state.
        for (word, bytes) in st[14..].iter_mut().zip(id.chunks_exact(4)) {
            *word = Self::load_u32_le(bytes);
        }
        StreamCipher { st }
    }

    /// Decode a little-endian `u32` from the first four bytes of `bytes`.
    fn load_u32_le(bytes: &[u8]) -> u32 {
        u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]])
    }

    /// Fill the given buffer with the keystream starting at the given offset.
    ///
    /// The offset is in bytes and does not need to be a multiple of the 64-byte
    /// block size.
    ///
    /// The key stream is deterministic: the same key, context and offset will always produce the same output.
    ///
    /// # Errors
    ///
    /// Returns `Err("offset would overflow")` if `start_offset + out.len()` does not
    /// fit in a `u64`; `out` is left untouched in that case.
    pub fn fill(&self, out: &mut [u8], start_offset: u64) -> Result<(), &'static str> {
        self.for_each_keystream_chunk(out, start_offset, |dst, keystream| {
            dst.copy_from_slice(keystream)
        })
    }

    /// Encrypt or decrypt the given buffer in place, given the offset.
    ///
    /// The buffer is modified in place.
    /// The offset is in bytes.
    ///
    /// The key stream is deterministic: the same key, context and offset will always produce the same output.
    /// This function is equivalent to calling `fill` and then XORing the output with the input.
    ///
    /// # Errors
    ///
    /// Returns `Err("offset would overflow")` if `start_offset + out.len()` does not
    /// fit in a `u64`; `out` is left untouched in that case.
    ///
    /// # Caveats
    ///
    /// * There is no integrity.
    /// * An adversary can flip arbitrary bits in the ciphertext and the corresponding bits in the plaintext will be flipped when decrypted.
    pub fn apply_keystream(&self, out: &mut [u8], start_offset: u64) -> Result<(), &'static str> {
        self.for_each_keystream_chunk(out, start_offset, |dst, keystream| {
            for (byte, k) in dst.iter_mut().zip(keystream) {
                *byte ^= *k;
            }
        })
    }

    /// Walk the keystream that lines up with `out`, starting at byte `start_offset`.
    ///
    /// For every keystream block overlapping the requested range, `combine` is
    /// called with the sub-slice of `out` covered by that block and the equally
    /// long slice of keystream bytes for it. Only the first block can begin
    /// part-way through; every later block is used from its first byte.
    ///
    /// Shared by `fill` and `apply_keystream` so the offset arithmetic exists once.
    fn for_each_keystream_chunk<F>(
        &self,
        out: &mut [u8],
        start_offset: u64,
        mut combine: F,
    ) -> Result<(), &'static str>
    where
        F: FnMut(&mut [u8], &[u8]),
    {
        // Reject before writing anything so callers never see a partial result.
        if start_offset.checked_add(out.len() as u64).is_none() {
            return Err("offset would overflow");
        }
        let block_len = Self::BLOCK_LENGTH as u64;
        let mut block_offset = start_offset / block_len;
        let mut skip = (start_offset % block_len) as usize;
        let mut pos = 0;
        while pos < out.len() {
            let take = cmp::min(Self::BLOCK_LENGTH - skip, out.len() - pos);
            let block = self.keystream_block(block_offset);
            combine(&mut out[pos..pos + take], &block[skip..skip + take]);
            pos += take;
            skip = 0;
            // Cannot overflow: the range check above bounds this by 2^58.
            block_offset += 1;
        }
        Ok(())
    }

    /// Compute the 64-byte keystream block with index `block_offset`.
    ///
    /// Takes `self` by value so the counter words are written into a copy of the
    /// state, leaving the caller's state untouched.
    #[inline(always)]
    fn keystream_block(mut self, block_offset: u64) -> [u8; Self::BLOCK_LENGTH] {
        // Low and high halves of the 64-bit block index; the truncation is intended.
        self.st[12] = block_offset as u32;
        self.st[13] = (block_offset >> 32) as u32;
        self.permute();
        let mut block = [0u8; Self::BLOCK_LENGTH];
        for (chunk, word) in block.chunks_exact_mut(4).zip(self.st.iter()) {
            chunk.copy_from_slice(&word.to_le_bytes());
        }
        block
    }

    /// The ChaCha quarter round on words `a`, `b`, `c`, `d` of `x`.
    ///
    /// Forced inline so the round loop compiles to the same straight-line code as
    /// a hand-expanded version.
    #[inline(always)]
    fn quarter_round(x: &mut [u32; 16], a: usize, b: usize, c: usize, d: usize) {
        x[a] = x[a].wrapping_add(x[b]);
        x[d] = (x[d] ^ x[a]).rotate_left(16);
        x[c] = x[c].wrapping_add(x[d]);
        x[b] = (x[b] ^ x[c]).rotate_left(12);
        x[a] = x[a].wrapping_add(x[b]);
        x[d] = (x[d] ^ x[a]).rotate_left(8);
        x[c] = x[c].wrapping_add(x[d]);
        x[b] = (x[b] ^ x[c]).rotate_left(7);
    }

    /// Apply the ChaCha block function to `self.st` in place.
    ///
    /// Runs `ROUNDS / 2` double rounds (a column round followed by a diagonal
    /// round), then adds the input state back word by word (the feed-forward).
    fn permute(&mut self) {
        let input = self.st;
        let x = &mut self.st;
        for _ in 0..Self::ROUNDS / 2 {
            // Column round.
            Self::quarter_round(x, 0, 4, 8, 12);
            Self::quarter_round(x, 1, 5, 9, 13);
            Self::quarter_round(x, 2, 6, 10, 14);
            Self::quarter_round(x, 3, 7, 11, 15);
            // Diagonal round.
            Self::quarter_round(x, 0, 5, 10, 15);
            Self::quarter_round(x, 1, 6, 11, 12);
            Self::quarter_round(x, 2, 7, 8, 13);
            Self::quarter_round(x, 3, 4, 9, 14);
        }
        for (word, initial) in x.iter_mut().zip(input.iter()) {
            *word = word.wrapping_add(*initial);
        }
    }
}
287
#[cfg(test)]
mod tests {
    use super::*;

    /// A fixed, non-trivial key so failures in the deterministic tests are reproducible.
    fn fixed_key() -> [u8; StreamCipher::KEY_LENGTH] {
        let mut key = [0u8; StreamCipher::KEY_LENGTH];
        for (i, byte) in key.iter_mut().enumerate() {
            *byte = i as u8;
        }
        key
    }

    /// Cross-checks `fill` against `apply_keystream` on a zeroed buffer under a
    /// random key, at aligned and unaligned start offsets.
    #[test]
    fn test_chacha() {
        let mut key = [0u8; StreamCipher::KEY_LENGTH];
        getrandom::fill(&mut key).unwrap();

        let st = StreamCipher::new(&key, b"testtest");

        let mut out = [0u8; 10000];
        st.apply_keystream(&mut out, 10).unwrap();

        let mut out2 = [0u8; 10000];
        st.fill(&mut out2, 10).unwrap();

        assert_eq!(out, out2);

        st.fill(&mut out2, 11).unwrap();
        assert_eq!(out[1..], out2[0..out2.len() - 1]);

        out.fill(0);
        st.apply_keystream(&mut out, 0).unwrap();
        st.fill(&mut out2, 0).unwrap();
        assert_eq!(out, out2);

        st.fill(&mut out, 64).unwrap();
        assert_eq!(out[..out2.len() - 64], out2[64..]);
    }

    /// Every sub-range of the keystream, including ranges that start or end
    /// mid-block, matches the same range of one long `fill` from offset 0.
    #[test]
    fn test_seek_matches_full_stream() {
        let st = StreamCipher::new(&fixed_key(), b"testtest");
        let mut full = [0u8; 300];
        st.fill(&mut full, 0).unwrap();
        for &start in [0usize, 1, 63, 64, 65, 130].iter() {
            for &len in [0usize, 1, 63, 64, 65, 170].iter() {
                let mut part = vec![0u8; len];
                st.fill(&mut part, start as u64).unwrap();
                assert_eq!(part[..], full[start..start + len]);
            }
        }
    }

    /// Applying the keystream twice at the same offset restores the input.
    #[test]
    fn test_apply_keystream_roundtrip() {
        let st = StreamCipher::new(&fixed_key(), b"testtest");
        let plain: Vec<u8> = (0..200u32).map(|i| (i * 7) as u8).collect();
        for &start in [0u64, 5, 64, 127].iter() {
            let mut data = plain.clone();
            st.apply_keystream(&mut data, start).unwrap();
            assert_ne!(data, plain);
            st.apply_keystream(&mut data, start).unwrap();
            assert_eq!(data, plain);
        }
    }

    /// Ranges whose end (`start_offset + len`) exceeds `u64::MAX` are rejected
    /// without touching the buffer; an end of exactly `u64::MAX` is accepted.
    #[test]
    fn test_offset_overflow() {
        let st = StreamCipher::new(&fixed_key(), b"testtest");
        let mut buf = [0u8; 2];
        assert_eq!(st.fill(&mut buf, u64::MAX), Err("offset would overflow"));
        assert_eq!(
            st.apply_keystream(&mut buf, u64::MAX - 1),
            Err("offset would overflow")
        );
        assert_eq!(buf, [0u8; 2]);
        assert_eq!(st.fill(&mut buf, u64::MAX - 2), Ok(()));
    }

    /// Different context identifiers yield different keystreams for the same key.
    #[test]
    fn test_context_separates_streams() {
        let key = fixed_key();
        let mut a = [0u8; 64];
        let mut b = [0u8; 64];
        StreamCipher::new(&key, b"testtest").fill(&mut a, 0).unwrap();
        StreamCipher::new(&key, b"othertst").fill(&mut b, 0).unwrap();
        assert_ne!(a, b);
    }
}
318}