1#![cfg(target_endian="little")]
98#![doc(cfg(target_endian="little"))]
99
100use {
101 alloc::vec::Vec,
102 crate::{Bytes, Result},
103 self::word_array::WordArray,
104};
105
106#[cfg(not(feature="std"))]
107use crate::io::Write;
108
109#[cfg(feature="std")]
110use {
111 std::io::Write,
112 crate::IoResult,
113};
114
115#[cfg(feature="simd")]
116use core::simd::Simd;
117
118#[cfg(test)]
119use {
120 core::mem,
121 crate::keccak::Hash,
122 openssl::symm::{Cipher, Crypter, Mode},
123};
124
125#[cfg(test)]
126#[cfg(feature="std")]
127use std::thread;
128
129#[cfg(test)]
130mod tests;
131
132mod variant;
133
134pub use self::variant::*;
135
136#[cfg(not(feature="simd"))]
137#[doc(cfg(not(feature="simd")))]
138pub (crate) mod salsa20;
139
140#[cfg(feature="simd")]
141#[doc(cfg(feature="simd"))]
142pub (crate) mod salsa20_simd;
143
144pub (crate) mod word_array;
145
146pub type Key = [u8; KEY_SIZE_IN_BYTES];
148
149const KEY_SIZE_IN_BYTES: usize = 32;
151
152pub type Nonce = [u8; NONCE_SIZE_IN_BYTES];
162
163const NONCE_SIZE_IN_BYTES: usize = 16;
165
166pub (crate) const BLOCK_SIZE_IN_BYTES: usize = 64;
168
169pub (crate) type Block = [u8; BLOCK_SIZE_IN_BYTES];
171
172#[cfg(any(feature="simd", test))]
173const BLOCK_OF_ZEROS: Block = [0; BLOCK_SIZE_IN_BYTES];
174
175#[cfg(test)]
176const KEY_HASH: Hash = Hash::Sha3_256;
177#[cfg(test)]
178const NONCE_HASH: Hash = Hash::Shake128;
179
180pub (crate) const TWENTY: Variant = Variant::Twenty;
182
183#[derive(Debug)]
189pub struct Chacha<W> where W: Write {
190 variant: Variant,
191 data: WordArray,
192 buffer: Vec<u8>,
193 output: W,
194}
195
196impl<W> Chacha<W> where W: Write {
197
198 pub const fn variant(&self) -> &Variant {
200 &self.variant
201 }
202
203 #[inline]
209 pub fn encrypt<B>(&mut self, bytes: B) -> Result<usize> where B: AsRef<[u8]> {
210 let mut bytes = bytes.as_ref();
211
212 let result = bytes.len();
213 if result == 0 {
214 return Ok(result);
215 }
216
217 if crate::io::fill_buffer(&mut self.buffer, BLOCK_SIZE_IN_BYTES, &mut bytes) {
219 self.process_buffer_and_write_to_output(None)?;
220 self.buffer.clear();
221 }
222
223 {
225 let chunks = bytes.chunks_exact(BLOCK_SIZE_IN_BYTES);
227 self.buffer.extend(chunks.remainder());
228 for c in chunks {
229 self.process_buffer_and_write_to_output(Some(c))?;
230 }
231 }
232
233 Ok(result)
234 }
235
236 pub fn encrypt_bytes<'a, const N: usize, B, B0>(&mut self, bytes: B) -> Result<Option<usize>>
242 where B: Into<Bytes<'a, N, B0>>, B0: AsRef<[u8]> + 'a {
243 let mut result = Some(usize::MIN);
244 for bytes in bytes.into().as_slice() {
245 let size = self.encrypt(bytes)?;
246 if let Some(current) = result.as_mut() {
247 match current.checked_add(size) {
248 Some(new) => *current = new,
249 None => result = None,
250 };
251 }
252 }
253
254 Ok(result)
255 }
256
257 #[inline]
264 fn process_buffer_and_write_to_output(&mut self, buffer: Option<&[u8]>) -> Result<()> {
265 let buffer = buffer.unwrap_or(&self.buffer);
266
267 let buffer_len = match buffer.len() {
268 0 => return Ok(()),
269 buffer_len @ 1..=BLOCK_SIZE_IN_BYTES => buffer_len,
270 other => return Err(err!("Buffer is too large: {other} (max allowed: {max})", other=other, max=BLOCK_SIZE_IN_BYTES)),
271 };
272
273 #[cfg(feature="simd")]
274 let mut output = salsa20_simd::words_to_bytes(&self.variant, &self.data);
275 #[cfg(not(feature="simd"))]
276 let mut output = salsa20::words_to_bytes(&self.variant, &self.data);
277 {
278 const TWELFTH: usize = 12;
279 self.data[TWELFTH] = self.data[TWELFTH].wrapping_add(1);
280 if self.data[TWELFTH] == 0 {
281 const THIRTEENTH: usize = TWELFTH + 1;
282 self.data[THIRTEENTH] = self.data[THIRTEENTH].wrapping_add(1);
283 }
284 }
285
286 #[cfg(feature="simd")] {
287 output = {
288 let mut tmp = BLOCK_OF_ZEROS;
289 tmp[..buffer_len].copy_from_slice(buffer);
290 (Simd::from_array(output) ^ Simd::from_array(tmp)).to_array()
291 };
292 }
293 #[cfg(not(feature="simd"))]
294 (0..buffer_len).for_each(|i| output[i] ^= buffer[i]);
295
296 let result = self.output.write_all(&output[..buffer_len]);
297 #[cfg(feature="std")]
298 let result = result.map_err(|e| from_io_err!(e));
299 result
300 }
301
302 #[cfg(feature="std")]
304 #[inline(always)]
305 pub (crate) const fn mut_output(&mut self) -> &mut W {
306 &mut self.output
307 }
308
309 #[must_use]
311 pub fn finish(mut self) -> Result<W> {
312 self.process_buffer_and_write_to_output(None)?;
313 drop(self.buffer);
314
315 #[cfg(feature="std")]
316 from_io_err!(self.output.flush())?;
317
318 Ok(self.output)
319 }
320
321}
322
/// Lets a [`Chacha`] be used wherever a `std::io::Write` is expected.
///
/// `write` always reports the whole input as consumed. `flush` only flushes the wrapped
/// writer: a buffered partial block stays pending until more data arrives or
/// [`Chacha::finish`] is called.
#[cfg(feature="std")]
#[doc(cfg(feature="std"))]
impl<W: Write> Write for Chacha<W> {

    fn write(&mut self, buffer: &[u8]) -> IoResult<usize> {
        // The crate error converts into `io::Error` through its `From` impl.
        self.encrypt(buffer).map_err(Into::into)
    }

    fn flush(&mut self) -> IoResult<()> {
        self.output.flush()
    }

}
336
#[test]
fn tests() {
    // Each entry is (label, actual, expected); every pair must match.
    let checks = [
        ("Block vs WordArray size", mem::size_of::<Block>(), mem::size_of::<WordArray>()),
        ("KEY_SIZE_IN_BYTES", KEY_SIZE_IN_BYTES, 32),
        ("Key size", mem::size_of::<Key>(), KEY_SIZE_IN_BYTES),
        ("NONCE_SIZE_IN_BYTES", NONCE_SIZE_IN_BYTES, 16),
        ("Nonce size", mem::size_of::<Nonce>(), NONCE_SIZE_IN_BYTES),
    ];
    for (label, actual, expected) in checks {
        assert_eq!(actual, expected, "{label}");
    }
}
347
/// Reference implementation for the tests: runs `bytes` through OpenSSL's ChaCha20 in `mode`.
///
/// Input is fed in 64-byte pieces through one reusable scratch block. Panics if OpenSSL
/// rejects the key/IV or fails, which is acceptable in test code.
#[cfg(test)]
fn encrypt_using_open_ssl_chacha20<K, N, B>(mode: Mode, key: K, nonce: N, bytes: B) -> Vec<u8>
where K: AsRef<[u8]>, N: AsRef<[u8]>, B: AsRef<[u8]> {
    let input = bytes.as_ref();

    let mut crypter = Crypter::new(Cipher::chacha20(), mode, key.as_ref(), Some(nonce.as_ref())).unwrap();
    let mut scratch = BLOCK_OF_ZEROS;
    let mut transformed = Vec::with_capacity(input.len());

    input.chunks(BLOCK_SIZE_IN_BYTES).for_each(|piece| {
        let written = crypter.update(piece, &mut scratch).unwrap();
        transformed.extend_from_slice(&scratch[..written]);
    });

    // ChaCha20 is a stream cipher, so this normally yields nothing, but drain it regardless.
    let written = crypter.finalize(&mut scratch).unwrap();
    transformed.extend_from_slice(&scratch[..written]);

    transformed
}
369
#[test]
fn cmp_to_data_generated_by_open_ssl() -> Result<()> {
    // 80 bytes: one full block plus a 16-byte partial block.
    const DATA: &[u8] = &[
        0x50, 0xa6, 0xb7, 0xec, 0xb4, 0x2c, 0xc8, 0x9a, 0xbd, 0x5c, 0x40, 0xea, 0x5e, 0x30, 0x66, 0x31, 0xbf, 0x93, 0x49, 0x6a, 0xc5, 0x01,
        0x56, 0xb3, 0x6f, 0x51, 0x56, 0xe8, 0x56, 0x89, 0xe8, 0x56, 0x08, 0xe5, 0xb0, 0xe6, 0xa0, 0xd9, 0x3c, 0x1c, 0x6a, 0x8a, 0xd7, 0x12,
        0x09, 0xf0, 0xac, 0x9d, 0x57, 0xb0, 0x45, 0x33, 0x9d, 0x1f, 0x60, 0x1d, 0x34, 0xf8, 0xa7, 0xa0, 0x4e, 0x42, 0x64, 0xae,
        0x98, 0xda, 0xad, 0x4a, 0x94, 0x82, 0xcc, 0x9b, 0x84, 0x82, 0x1f, 0x50, 0x4f, 0xc0, 0x44, 0xba,
    ];

    let key = &KEY_HASH.hash("key");
    let nonce = &NONCE_HASH.hash("nonce");
    let ours = |data| TWENTY.encrypt(key, nonce, data);
    let theirs = |mode, data| encrypt_using_open_ssl_chacha20(mode, key, nonce, data);

    // Encryption keeps the length, changes the content, and matches OpenSSL byte for byte.
    let ciphertext = ours(DATA)?;
    assert_eq!(DATA.len(), ciphertext.len());
    assert_ne!(DATA, ciphertext);
    assert_eq!(ciphertext, theirs(Mode::Encrypt, DATA));

    // Running the ciphertext back through either implementation restores the plaintext.
    assert_eq!(DATA, ours(&ciphertext)?);
    assert_eq!(DATA, theirs(Mode::Decrypt, &ciphertext));

    Ok(())
}
402
#[test]
#[cfg(feature="std")]
fn cmp_to_data_generated_by_open_ssl_using_threads() -> Result<()> {
    // One full block plus a single extra byte, so the partial-block path is exercised too.
    const INPUT_DATA: &[u8] = &[u8::MIN; BLOCK_SIZE_IN_BYTES + 1];

    // Spawn every worker first, then join them all. Index -1 uses an all-0xFF key.
    let workers: Vec<_> = (-1_i16..=999)
        .map(|index| thread::spawn(move || {
            let key = &match index {
                i if i < 0 => alloc::vec![u8::MAX; KEY_SIZE_IN_BYTES],
                i => KEY_HASH.hash(i.to_be_bytes()),
            };
            let nonce = &if index % 2 != 0 {
                NONCE_HASH.hash(index.to_le_bytes())
            } else {
                alloc::vec![index as u8; NONCE_SIZE_IN_BYTES]
            };
            assert_ne!(key, nonce);

            let ours = |data| TWENTY.encrypt(key, nonce, data);
            let theirs = |mode, data| encrypt_using_open_ssl_chacha20(mode, key, nonce, data);

            let ciphertext = ours(INPUT_DATA)?;
            assert_eq!(INPUT_DATA.len(), ciphertext.len());
            assert_eq!(ciphertext, theirs(Mode::Encrypt, INPUT_DATA));

            assert_eq!(INPUT_DATA, ours(&ciphertext)?);
            assert_eq!(INPUT_DATA, theirs(Mode::Decrypt, &ciphertext));

            // Spelled via the crate `Result` alias so `?` above can infer the error type.
            Result::Ok(())
        }))
        .collect();

    for worker in workers {
        worker.join().unwrap().unwrap();
    }

    Ok(())
}