strong_box/
static_strong_box.rs

1use chacha20poly1305::{
2	ChaCha20Poly1305, KeyInit as _,
3	aead::{Aead as _, Payload},
4};
5use rand::{RngCore, rng};
6use std::{collections::HashMap, fmt::Debug};
7
8use super::{Error, Key, KeyId, StrongBox};
9
10/// A secure symmetric encryption container, supporting key rotation and AAD contexts.
11///
12/// This is your basic, Mark 1 mod 0 [`StrongBox`].  Given an encryption key, it will
13/// encrypt data all day long with a modern, fast cipher (ChaCha20) with integrity protection and
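/// authenticated additional data (using Poly1305).  If provided with one or more decryption keys,
/// it will decrypt data that was encrypted with *any* of those keys, giving you the ability to
/// "rotate" your key over time, by creating a new key, making it the new encryption key, and
/// keeping the old key in the set of "decryption" keys until such time as all data has been
/// re-encrypted with the new key.  Note that the encryption key is *not* automatically part of
/// the set of decryption keys: if a [`StaticStrongBox`] should be able to decrypt what it
/// encrypts, include the encryption key in the decryption keys as well (as the example below does).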
19///
20/// The "authenticated additional data" is a mouthful, but what it means is that when you encrypt
21/// data, you provide the encryption with a "context", such as the ID of the user that the
22/// encrypted data belongs to.  When you decrypt the data again, you provide the ID of the user the
23/// data belongs to, and if they don't match, decryption fails.  Why is that useful?  Because if
24/// an attacker gets write access to the database, and moves encrypted data from one user to
25/// another, Bad Things can happen.  [This Security StackExchange answer](https://security.stackexchange.com/a/179279/167630) is an excellent explanation of
26/// why an encryption context is useful.
27///
28/// # Example
29///
30/// ```rust
31/// use strong_box::{Error, StaticStrongBox, StrongBox};
32/// # fn main() -> Result<(), Error> {
33///
34/// // A couple of keys are always useful to have
35/// let old_key = strong_box::generate_key();
36/// let new_key = strong_box::generate_key();
37///
38/// let old_strongbox = StaticStrongBox::new(old_key.clone(), [old_key.clone()]);
39/// let new_strongbox = StaticStrongBox::new(new_key.clone(), [new_key.clone()]);
40/// // This StaticStrongBox encrypts with `new_key`, but can decrypt ciphertexts
41/// // encrypted with *either* `new_key` *or* `old_key`
42/// let fallback_strongbox = StaticStrongBox::new(new_key.clone(), vec![new_key, old_key]);
43///
44/// /////////////////////////////////////////////////////////
45/// // A ciphertext encrypted using the old key
46///
47/// let ciphertext = old_strongbox.encrypt(b"Hello, old world!", b"some context")?;
48///
49/// // We'd *hope* that we can decrypt what we encrypted
50/// assert_eq!(
51///     b"Hello, old world!".to_vec(),
52///     old_strongbox.decrypt(&ciphertext, b"some context")?
53/// );
54///
55/// // A StaticStrongBox that uses a different key won't be able to decrypt
56/// let result = new_strongbox.decrypt(&ciphertext, b"some context");
57/// assert!(matches!(result, Err(Error::Decryption)));
58///
59/// // Also, a StaticStrongBox that uses the right key won't decrypt if the context isn't the
60/// // same as was used to encrypt
61/// let result = old_strongbox.decrypt(&ciphertext, b"a different context");
62/// assert!(matches!(result, Err(Error::Decryption)));
63///
64/// // However, magic of magicks, the fallback StaticStrongBox can do the job!
65/// assert_eq!(
66///     b"Hello, old world!".to_vec(),
67///     fallback_strongbox.decrypt(&ciphertext, b"some context")?
68/// );
69///
70/// //////////////////////////////////////////////////////////////
71/// // Now, let's try a ciphertext encrypted using the new key
72///
73/// let ciphertext = new_strongbox.encrypt(b"Hello, new world!", b"new context")?;
74///
75/// // Again, the same StaticStrongBox should be able to decrypt
76/// assert_eq!(
77///     b"Hello, new world!".to_vec(),
78///     new_strongbox.decrypt(&ciphertext, b"new context")?
79/// );
80///
81/// // Unsurprisingly, the fallback StaticStrongBox can decrypt it too, as it uses the same key
82/// assert_eq!(
83///     b"Hello, new world!".to_vec(),
84///     fallback_strongbox.decrypt(&ciphertext, b"new context")?
85/// );
86///
87/// // A StaticStrongBox using just the old key won't be able to decrypt, though
88/// let result = old_strongbox.decrypt(&ciphertext, b"new context");
89/// assert!(matches!(result, Err(Error::Decryption)));
90///
91/// // And again, the right StaticStrongBox but the wrong context won't decrypt
92/// let result = new_strongbox.decrypt(&ciphertext, b"some other context");
93/// assert!(matches!(result, Err(Error::Decryption)));
94/// # Ok(())
95/// # }
96/// ```
// NOTE(review): the derived `Debug` prints every field, including the keys; this relies on
// `Key`'s own `Debug` impl not revealing secret material — confirm.
#[derive(Clone, Debug)]
pub struct StaticStrongBox {
	/// The key used for every encryption performed by this box.
	encryption_key: Key,
	/// `key_id` of `encryption_key`, computed once in `new` and embedded (and authenticated)
	/// in every ciphertext this box produces.
	encryption_key_id: KeyId,
	/// Every key this box will decrypt with, indexed by key ID so the ID stored in a
	/// ciphertext selects the key directly.
	decryption_keys: HashMap<KeyId, Key>,
}
103
104impl StaticStrongBox {
105	/// Create a new [`StaticStrongBox`].
106	#[tracing::instrument(level = "debug", skip(enc_key, dec_keys))]
107	pub fn new(
108		enc_key: impl Into<Key>,
109		dec_keys: impl IntoIterator<Item = impl Into<Key>>,
110	) -> Self {
111		let mut key_map: HashMap<KeyId, Key> = HashMap::default();
112
113		for key in dec_keys.into_iter() {
114			let key = key.into();
115			let key_id = super::key_id(&key);
116			tracing::debug!(%key_id, "Including decryption key");
117			key_map.insert(key_id, key);
118		}
119
120		let enc_key = enc_key.into();
121		let enc_key_id = super::key_id(&enc_key);
122		tracing::debug!("Encryption key is {enc_key_id}");
123
124		Self {
125			encryption_key_id: enc_key_id,
126			encryption_key: enc_key,
127			decryption_keys: key_map,
128		}
129	}
130
131	pub(crate) fn decrypt_ciphertext(
132		&self,
133		ciphertext: &Ciphertext,
134		ctx: &[u8],
135	) -> Result<Vec<u8>, Error> {
136		if let Some(key) = self.decryption_keys.get(&ciphertext.key_id) {
137			tracing::debug!(key_id=%ciphertext.key_id, "Decrypting");
138
139			let mut aad = Vec::<u8>::new();
140			aad.extend_from_slice(ctx.as_ref());
141			aad.extend_from_slice(ciphertext.key_id.as_bytes());
142			aad.extend_from_slice(&ciphertext.nonce);
143
144			let cipher = ChaCha20Poly1305::new(key.expose_secret().into());
145			let payload = Payload {
146				msg: &ciphertext.ciphertext,
147				aad: &aad,
148			};
149			cipher
150				.decrypt((&ciphertext.nonce[..]).into(), payload)
151				.map_err(|_| Error::Decryption)
152		} else {
153			tracing::debug!(key_id=%ciphertext.key_id, "Decryption key not found");
154			Err(Error::Decryption)
155		}
156	}
157}
158
159impl StrongBox for StaticStrongBox {
160	#[tracing::instrument(level = "debug", skip(plaintext))]
161	fn encrypt(
162		&self,
163		plaintext: impl AsRef<[u8]>,
164		ctx: impl AsRef<[u8]> + Debug,
165	) -> Result<Vec<u8>, Error> {
166		let cipher = ChaCha20Poly1305::new((self.encryption_key.expose_secret()).into());
167		let mut rng = rng();
168		let mut nonce = [0u8; 12];
169		rng.fill_bytes(&mut nonce);
170
171		let mut aad = Vec::<u8>::new();
172		aad.extend_from_slice(ctx.as_ref());
173		aad.extend_from_slice(self.encryption_key_id.as_bytes());
174		aad.extend_from_slice(&nonce);
175
176		let ciphertext = cipher
177			.encrypt(
178				(&nonce).into(),
179				Payload {
180					msg: plaintext.as_ref(),
181					aad: &aad,
182				},
183			)
184			.map_err(|_| Error::Encryption)?;
185		tracing::debug!(key_id=%self.encryption_key_id, "Encrypting");
186
187		Ciphertext::new(self.encryption_key_id, nonce, ciphertext).to_bytes()
188	}
189
190	#[tracing::instrument(level = "debug", skip(ciphertext))]
191	fn decrypt(
192		&self,
193		ciphertext: impl AsRef<[u8]>,
194		ctx: impl AsRef<[u8]> + Debug,
195	) -> Result<Vec<u8>, Error> {
196		let ciphertext = Ciphertext::try_from(ciphertext.as_ref())?;
197
198		self.decrypt_ciphertext(&ciphertext, ctx.as_ref())
199	}
200}
201
// Prefix written in front of every serialised `Ciphertext` and checked when parsing.  In base64
// these three bytes encode to `sbj1`, so every base64-encoded ciphertext starts with that string.
const CIPHERTEXT_MAGIC: [u8; 3] = *b"\xb1\xb8\xf5";
204
/// The decoded form of a ciphertext produced by [`StaticStrongBox`].
///
/// On the wire this is `CIPHERTEXT_MAGIC` followed by the CBOR array
/// `[key_id, nonce, ciphertext]`.  It is written by `Ciphertext::to_bytes` and parsed by the
/// `TryFrom<&[u8]>` impl.
#[derive(Clone, Debug)]
pub(crate) struct Ciphertext {
	/// ID of the key the payload was encrypted with; used to look up the decryption key.
	pub(crate) key_id: KeyId,
	/// The 12-byte nonce passed to ChaCha20-Poly1305 for this message.
	pub(crate) nonce: [u8; 12],
	/// The raw output of `ChaCha20Poly1305::encrypt` for this message.
	pub(crate) ciphertext: Vec<u8>,
}
211
212impl Ciphertext {
213	pub(crate) fn new(key_id: KeyId, nonce: [u8; 12], ciphertext: Vec<u8>) -> Self {
214		Self {
215			key_id,
216			nonce,
217			ciphertext,
218		}
219	}
220
221	pub(crate) fn to_bytes(&self) -> Result<Vec<u8>, Error> {
222		use ciborium_ll::{Encoder, Header};
223
224		let mut v: Vec<u8> = Vec::new();
225
226		v.extend_from_slice(&CIPHERTEXT_MAGIC);
227
228		let mut enc = Encoder::from(&mut v);
229		enc.push(Header::Array(Some(3)))
230			.map_err(|e| Error::ciphertext_encoding("key_id", e))?;
231		self.key_id.encode(&mut enc)?;
232		enc.bytes(&self.nonce, None)
233			.map_err(|e| Error::ciphertext_encoding("nonce", e))?;
234		enc.bytes(&self.ciphertext, None)
235			.map_err(|e| Error::ciphertext_encoding("ciphertext", e))?;
236
237		tracing::debug!(
238			nonce = self
239				.nonce
240				.iter()
241				.map(|i| format!("{i:02x}"))
242				.collect::<Vec<_>>()
243				.join(""),
244			ct = self
245				.ciphertext
246				.iter()
247				.map(|i| format!("{i:02x}"))
248				.collect::<Vec<_>>()
249				.join(""),
250			"{}",
251			v.iter()
252				.map(|i| format!("{i:02x}"))
253				.collect::<Vec<_>>()
254				.join("")
255		);
256		Ok(v)
257	}
258}
259
260impl TryFrom<&[u8]> for Ciphertext {
261	type Error = Error;
262
263	fn try_from(b: &[u8]) -> Result<Self, Self::Error> {
264		use ciborium_ll::{Decoder, Header};
265
266		if b.len() < 21 {
267			return Err(Error::invalid_ciphertext("too short"));
268		}
269
270		if b[0..3] != CIPHERTEXT_MAGIC {
271			tracing::debug!(magic=?CIPHERTEXT_MAGIC, actual=?b[0..3]);
272			return Err(Error::invalid_ciphertext("incorrect magic"));
273		}
274
275		let mut dec = Decoder::from(&b[3..]);
276
277		let Header::Array(Some(3)) = dec
278			.pull()
279			.map_err(|e| Error::ciphertext_decoding("array", e))?
280		else {
281			return Err(Error::invalid_ciphertext("expected array"));
282		};
283
284		let key_id = KeyId::decode(&mut dec)?;
285
286		// CBOR's great, until you have to deal with segmented bytestrings...
287		let Header::Bytes(len) = dec
288			.pull()
289			.map_err(|e| Error::ciphertext_decoding("nonce header", e))?
290		else {
291			return Err(Error::invalid_ciphertext("expected nonce"));
292		};
293
294		let mut segments = dec.bytes(len);
295
296		let Ok(Some(mut segment)) = segments.pull() else {
297			return Err(Error::invalid_ciphertext("bad nonce"));
298		};
299
300		let mut buf = [0u8; 1024];
301		let mut nonce = [0u8; 12];
302
303		if let Some(chunk) = segment
304			.pull(&mut buf[..])
305			.map_err(|e| Error::ciphertext_decoding("nonce", e))?
306		{
307			// Is this necessary?  Probably better to be safe than sorry
308			nonce[..].copy_from_slice(chunk);
309		} else {
310			return Err(Error::invalid_ciphertext("short nonce"));
311		}
312
313		// ibid.
314		let Header::Bytes(len) = dec
315			.pull()
316			.map_err(|e| Error::ciphertext_decoding("ciphertext header", e))?
317		else {
318			return Err(Error::invalid_ciphertext("expected ciphertext"));
319		};
320
321		let mut segments = dec.bytes(len);
322
323		let Ok(Some(mut segment)) = segments.pull() else {
324			return Err(Error::invalid_ciphertext("bad ciphertext"));
325		};
326
327		let mut ciphertext: Vec<u8> = Vec::new();
328
329		while let Some(chunk) = segment
330			.pull(&mut buf[..])
331			.map_err(|e| Error::ciphertext_decoding("ciphertext", e))?
332		{
333			ciphertext.extend_from_slice(chunk);
334		}
335
336		Ok(Self {
337			key_id,
338			nonce,
339			ciphertext,
340		})
341	}
342}