1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
//! Intermediate types used for encryption and decryption <span style="color:red">**HAZMAT**</span>
//!
//! This module provides wrapper types for shuffling data back and forth between encrypted and
//! unencrypted representations, as well as the ability to, optionally, transparently compress plaintext
//! before encryption.
//!
//! # <span style="color:red">**DANGER**</span>
//!
//! This module deals in low level cryptographic details. It is advisable to not deal with this module
//! directly, and instead use a higher level API.

use std::{
    borrow::Cow,
    io::Cursor,
    ops::{Deref, DerefMut},
};

use chacha20::{
    cipher::{NewCipher, StreamCipher},
    XChaCha20,
};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use snafu::{ensure, ResultExt};
use zeroize::Zeroize;
use zstd::stream::encode_all;

use crate::{
    crypto::key::{Key, Nonce},
    error::{BadHMAC, Compression, Decompression, Error},
};

/// An unencrypted blob of plaintext.
///
/// This type exists to facilitate marshaling data to a serialized, encrypted representation.
///
/// The payload is zeroized when a `ClearText` is dropped (`#[zeroize(drop)]`). Because that
/// installs a `Drop` impl, the payload cannot be moved out of a `ClearText` by value.
///
/// NOTE(review): the derived `PartialEq`/`Ord`/`Hash` impls compare and hash the secret payload
/// with ordinary, non-constant-time operations; avoid relying on them for anything
/// timing-sensitive.
#[derive(Hash, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Zeroize)]
#[zeroize(drop)]
pub struct ClearText {
    /// The cleartext payload: the CBOR encoding produced by `ClearText::new`, or the decrypted
    /// (and, if needed, decompressed) bytes produced by `CipherText::decrypt`
    pub(crate) payload: Vec<u8>,
}

impl ClearText {
    /// Creates a new `Cleartext` from a serializeable object
    ///
    /// # Errors
    ///
    /// Will return an error if serialization fails
    pub fn new<T>(item: &T) -> Result<Self, Error>
    where
        T: Serialize,
    {
        // Attempt to serialize the item
        match serde_cbor::to_vec(item) {
            Ok(payload) => Ok(ClearText { payload }),
            Err(_) => {
                // Do not preserve the underlying serde error, as this may secrets into the logs
                Err(Error::ItemSerialization)
            }
        }
    }

    /// Attempts to create an encrypted version of this `Cleartext` using the provided key. `ZStd`
    /// compression will be applied to the plain text before encryption with the provided level, if the
    /// compression option is set with a `Some` value.
    ///
    /// # <span style="color:red">**DANGER**</span>
    ///
    /// Compression can be incredibly dangerous when combined with encryption, _do not_ set the compression
    /// flag unless you know what you are doing, and you are 100% sure that compression related attacks do
    /// not fall into your threat model.
    ///
    ///
    /// # Errors
    ///
    /// Will return an error if encryption fails.
    pub fn encrypt<K>(self, key: &K, compression: Option<i32>) -> Result<CipherText<'static>, Error>
    where
        K: Key,
    {
        // Compress the payload, if requested
        let mut payload = if let Some(level) = compression {
            let input = Cursor::new(&self.payload);
            encode_all(input, level).context(Compression)?
        } else {
            self.payload.clone()
        };
        // Perform the encryption
        let nonce = Nonce::random();
        let mut chacha = XChaCha20::new(key.encryption_key(), nonce.nonce());
        chacha.apply_keystream(&mut payload[..]);
        // Generate the hmac tag
        let hmac: [u8; 32] = blake3::keyed_hash(key.hmac_key(), &payload[..]).into();
        Ok(CipherText {
            compressed: compression.is_some(),
            nonce,
            hmac: hmac.into(),
            payload: payload.into(),
        })
    }

    /// Converts this `Cleartext` back into its original type.
    ///
    /// # Errors
    ///
    /// Will return `Err(Error::ItemDeserialization)` if the serialization. Be warned, since this
    /// intentionally erases the underlying `serde` error, it can be difficult to tell if this is due to
    /// data corruption, or simply calling this method with the wrong type argument.
    pub fn deserialize<T>(&self) -> Result<T, Error>
    where
        T: DeserializeOwned,
    {
        match serde_cbor::from_slice(&self.payload) {
            Ok(x) => Ok(x),
            Err(_) => Err(Error::ItemDeserialization),
        }
    }
}

/// An encrypted plaintext, with associated data
///
/// This structure contains the payload, encrypted with `XChaCha20`, the nonce that was used to encrypt
/// it, as well as the HMAC of the encrypted payload. It also includes a flag indicating whether or not
/// to treat the plaintext as compressed
///
/// The "HMAC" is a keyed BLAKE3 hash (`blake3::keyed_hash`) computed over the encrypted payload
/// bytes only.
///
/// NOTE(review): because the tag covers only `payload`, the `nonce` and `compressed` fields are
/// not authenticated; altering them is not reported as `Error::BadHMAC`. Binding them into the tag
/// would change the tag computation and invalidate existing ciphertexts, so it needs a format
/// version bump.
///
/// Field order is significant: it fixes the serialized field order and the derived `Ord`.
#[derive(Debug, Hash, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct CipherText<'a> {
    /// Whether or not the payload is compressed
    pub(crate) compressed: bool,
    /// The nonce used for the encryption
    pub(crate) nonce: Nonce,
    /// The HMAC tag (keyed BLAKE3 over the encrypted `payload`)
    pub(crate) hmac: ConstArray<32>,
    /// The encrypted payload, either borrowed or owned depending on how this value was built
    ///
    /// NOTE(review): this serializes as a byte string via `serde_bytes`, but there is no matching
    /// `deserialize_with`, so the default `Cow<[u8]>` deserializer (which expects a sequence of
    /// integers) is used. For formats that distinguish the two (CBOR does), a serialize/deserialize
    /// round trip may fail — confirm, and consider `deserialize_with = "serde_bytes::deserialize"`.
    #[serde(serialize_with = "serde_bytes::serialize")]
    pub(crate) payload: Cow<'a, [u8]>,
}

impl CipherText<'_> {
    /// Attempts to decrypt the `CipherText` with the given key, turning it into a `ClearText`. This
    /// will also decompress the payload, if it was compressed.
    ///
    /// The tag is verified before anything is decrypted. It only covers the encrypted payload
    /// bytes, not the nonce or the compression flag.
    ///
    /// # Errors
    ///
    /// Will return:
    ///   * `Error::BadHMAC` - If the hmac tag fails to validate (wrong key or corrupted payload)
    ///   * `Error::Decompression` - If compressed data fails to decompress
    pub fn decrypt<K>(&self, key: &K) -> Result<ClearText, Error>
    where
        K: Key,
    {
        // Verify the mac; `blake3::Hash` compares against a 32-byte slice in constant time
        let hmac = blake3::keyed_hash(key.hmac_key(), &self.payload[..]);
        ensure!(hmac.eq(&*self.hmac), BadHMAC);
        // Copy the bytes into a local buffer and decrypt in place. From here on this buffer holds
        // plaintext, so every path must either zeroize it or hand it to a `ClearText` (which
        // zeroizes on drop).
        let mut payload = self.payload.to_vec();
        let mut chacha = XChaCha20::new(key.encryption_key(), self.nonce.nonce());
        chacha.apply_keystream(&mut payload[..]);
        if !self.compressed {
            return Ok(ClearText { payload });
        }
        // Decompress, then wipe the compressed plaintext *before* propagating any error, so a
        // decompression failure does not leave decrypted bytes behind in freed memory.
        let decompressed = zstd::decode_all(Cursor::new(&payload[..]));
        payload.zeroize();
        let output = decompressed.context(Decompression)?;
        Ok(ClearText { payload: output })
    }

    /// Returns true if this `CipherText` is compressed
    pub fn compressed(&self) -> bool {
        self.compressed
    }
}

/// Thin wrapper around a const-generic byte array to make it work better with serde
///
/// It serializes as a single byte string (rather than a sequence of `N` integers), and
/// deserialization rejects byte strings whose length is not exactly `N`. The contents are
/// zeroized on drop (`#[zeroize(drop)]`). Within this file it holds the 32-byte tag of a
/// `CipherText`.
#[derive(Debug, Hash, Clone, PartialEq, Eq, PartialOrd, Ord, Zeroize)]
#[zeroize(drop)]
pub struct ConstArray<const N: usize>(pub [u8; N]);

impl<const N: usize> Deref for ConstArray<N> {
    type Target = [u8];

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl<const N: usize> DerefMut for ConstArray<N> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}

impl<const N: usize> From<[u8; N]> for ConstArray<N> {
    fn from(x: [u8; N]) -> Self {
        Self(x)
    }
}

impl<const N: usize> From<ConstArray<N>> for [u8; N] {
    fn from(val: ConstArray<N>) -> Self {
        val.0
    }
}

impl<const N: usize> Serialize for ConstArray<N> {
    /// Emits the array as one byte string rather than as a sequence of `N` integers.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        serializer.serialize_bytes(&self.0)
    }
}

impl<const N: usize> AsRef<[u8; N]> for ConstArray<N> {
    fn as_ref(&self) -> &[u8; N] {
        &self.0
    }
}

impl<'de, const N: usize> Deserialize<'de> for ConstArray<N> {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let bytes = serde_bytes::ByteBuf::deserialize(deserializer)?;
        let mut new_bytes = [0_u8; N];
        if bytes.len() == N {
            new_bytes.copy_from_slice(&bytes);
            Ok(Self(new_bytes))
        } else {
            Err(serde::de::Error::invalid_length(
                bytes.len(),
                &"Wrong length",
            ))
        }
    }
}

/// Unit tests
#[cfg(test)]
mod tests {
    use super::*;
    use crate::crypto::key::RootKey;

    /// Plaintext item shared by the tests below
    const ITEM: &str = "The quick brown fox jumps over the lazy dog";

    /// Tests for the cleartext/ciphertext pair of types
    mod text {
        use super::*;

        /// Serializes `ITEM` and encrypts it under `key` with the given compression setting
        fn encrypt_item(key: &RootKey, compression: Option<i32>) -> CipherText<'static> {
            ClearText::new(&ITEM)
                .expect("Failed to make cleartext")
                .encrypt(key, compression)
                .expect("Failed to encrypt")
        }

        /// Decrypts `ciphertext` under `key` and deserializes it back into a `String`
        fn decrypt_item(key: &RootKey, ciphertext: &CipherText<'_>) -> String {
            ciphertext
                .decrypt(key)
                .expect("Failed to decrypt")
                .deserialize()
                .expect("Failed to deserialize")
        }

        /// Test round trip without compression
        #[test]
        fn round_trip() {
            let key = RootKey::random();
            let ciphertext = encrypt_item(&key, None);
            assert_eq!(decrypt_item(&key, &ciphertext), ITEM);
        }

        /// Test round trip with compression
        #[test]
        fn round_trip_compression() {
            let key = RootKey::random();
            let ciphertext = encrypt_item(&key, Some(0));
            assert_eq!(decrypt_item(&key, &ciphertext), ITEM);
        }

        /// The compression flag reflects the option passed to `encrypt`
        #[test]
        fn compressed_flag() {
            let key = RootKey::random();
            assert!(!encrypt_item(&key, None).compressed());
            assert!(encrypt_item(&key, Some(0)).compressed());
        }

        /// Make sure repeated invocations are non-equal
        #[test]
        fn repeated_invocations() {
            let key = RootKey::random();
            let ciphertext_1 = encrypt_item(&key, None);
            let ciphertext_2 = encrypt_item(&key, None);
            assert_ne!(ciphertext_1.nonce, ciphertext_2.nonce);
            assert_ne!(ciphertext_1.payload, ciphertext_2.payload);
        }

        /// Make sure corrupted data doesn't decrypt, with or without compression
        #[test]
        fn corruption() {
            let key = RootKey::random();
            for compression in [None, Some(0)].iter().copied() {
                let mut ciphertext = encrypt_item(&key, compression);
                // Corrupt the first byte of the payload
                let payload = ciphertext.payload.to_mut();
                payload[0] = payload[0].wrapping_add(1_u8);
                assert!(
                    matches!(ciphertext.decrypt(&key), Err(Error::BadHMAC)),
                    "Corrupted data was not rejected (compression: {:?})",
                    compression
                );
            }
        }

        /// Make sure data can't be decrypted with the wrong key
        #[test]
        fn wrong_key() {
            let key = RootKey::random();
            let wrong_key = RootKey::random();
            let ciphertext = encrypt_item(&key, Some(0));
            match ciphertext.decrypt(&wrong_key) {
                Ok(_) => panic!("Somehow decrypted data with the wrong key"),
                Err(e) => assert!(matches!(e, Error::BadHMAC)),
            }
        }

        /// Deserializing into the wrong type reports `ItemDeserialization`
        #[test]
        fn wrong_type() {
            let cleartext = ClearText::new(&ITEM).expect("Failed to make cleartext");
            assert!(matches!(
                cleartext.deserialize::<u64>(),
                Err(Error::ItemDeserialization)
            ));
        }
    }

    /// Tests for the serde behaviour of `ConstArray`
    mod const_array {
        use super::*;

        /// A correctly sized byte string round-trips through CBOR
        #[test]
        fn round_trip() {
            let original = ConstArray::<32>::from([7_u8; 32]);
            let bytes = serde_cbor::to_vec(&original).expect("Failed to serialize");
            let decoded: ConstArray<32> =
                serde_cbor::from_slice(&bytes).expect("Failed to deserialize");
            assert_eq!(decoded, original);
        }

        /// A byte string of the wrong length is rejected
        #[test]
        fn wrong_length() {
            let bytes = serde_cbor::to_vec(&ConstArray::<32>::from([7_u8; 32]))
                .expect("Failed to serialize");
            assert!(serde_cbor::from_slice::<ConstArray<16>>(&bytes).is_err());
        }
    }
}