strong_box/
rotating_strong_box.rs

1use parking_lot::RwLock;
2use std::{
3	collections::{BTreeMap, BTreeSet},
4	fmt::Debug,
5	sync::Arc,
6	time::Duration,
7};
8
9use super::{Ciphertext, Error, Key, KeyId, StaticStrongBox, StrongBox, kdf};
10
/// A [`StrongBox`] variant that uses a different set of keys for each period of time.
///
/// This box is primarily useful when you're encrypting data that only has to be valid for a
/// relatively short period of time (such as browser cookies and other shortish-lived tokens).  It
/// is particularly beneficial if you could potentially do a lot (ie "many billions") of
/// encryptions over the lifetime of the root keys (because, for various tedious reasons, the
/// risk of a cryptographic break increases based on the number of different encryptions done by
/// one key).
///
/// To use it, you need to specify the "root" encryption and decryption keys, as with any other strong
/// box, but also the [`Duration`] for which each key is valid, and also the number of previous
/// time periods which you want to still be able to decrypt.
///
/// The way it works is that time is divided up into periods, each of which has a [`Duration`] of
/// `lifespan`, specified when the [`RotatingStrongBox`] is created.  When an encryption operation is
/// performed, the current period is determined (by looking at the clock and dividing by `lifespan`),
/// the encryption key for the current period is derived from the "root" encryption key, and the
/// data is encrypted as normal with that "current encryption key".  So far, so good.
///
/// When a ciphertext is presented for decryption, things get a bit more involved.
/// The [`RotatingStrongBox`] needs to figure out which key to use for decryption, by deriving the
/// decryption keys from the "root" decryption keys specified when the [`RotatingStrongBox`] was
/// created, both for the time period at the time the decryption happens, *as well as* the previous
/// time periods, up to the limit specified by `backtrack`.  If the decryption key for the
/// ciphertext is one of those keys, then decryption happens as normal.  Otherwise, you're out of
/// luck, and the decryption fails.
///
/// Since deriving lots of keys can start to take a little bit of time, the set of decryption keys
/// is cached, and also shared amongst all the clones of a given [`RotatingStrongBox`].  The
/// maximum amount of memory that will be used by the cache (and the amount of time needed to
/// derive all the keys) is determined by the number of separate decryption keys, multiplied by the
/// number of `backtrack` periods allowed.  Each key is relatively small, so don't worry too much,
/// but also don't go "oh, I'll make my key lifespan 30 seconds and cache keys for 10 years"
/// without being ready for a certain amount of bloat.
#[derive(Clone, Debug)]
pub struct RotatingStrongBox {
	// This is just a way for us to test that periods and keys are generated properly,
	// by fiddling with time in unit tests
	time: Clock,

	// Root keys, rotation settings and derived per-period keys.  Held in an `Arc` so that
	// every clone of this box works against the same lock-guarded cache.
	key_cache: Arc<RwLock<KeyCache>>,
}
53
// We don't panic while holding the lock
// NOTE(review): `KeyCache::refresh_cache` is called with the write lock held and contains an
// `expect` on `root_decryption_keys` being non-empty, so the claim above relies on whoever
// calls `new` having validated the decryption keys -- confirm.
impl std::panic::UnwindSafe for RotatingStrongBox {}
56
57impl RotatingStrongBox {
58	#[tracing::instrument(level = "trace")]
59	pub(super) fn new(
60		enc_key: Key,
61		dec_keys: Vec<Key>,
62		lifespan: Duration,
63		backtrack: u16,
64	) -> Self {
65		Self {
66			#[cfg(not(test))]
67			time: Clock,
68			#[cfg(test)]
69			time: Clock::default(),
70			key_cache: Arc::new(RwLock::new(KeyCache {
71				lifespan,
72				backtrack,
73				root_encryption_key: enc_key,
74				root_decryption_keys: dec_keys,
75				current_encryptor: CachedStrongBox::new(Box::new([0u8; 32]).into(), Timestamp(0)),
76				decryptor_cache: BTreeMap::default(),
77				decryptor_validities: BTreeSet::default(),
78				cache_invalid_at: Timestamp::default(),
79			})),
80		}
81	}
82
83	#[tracing::instrument(level = "trace", skip(self, ciphertext))]
84	fn try_decrypt_with(&self, ciphertext: &Ciphertext, ctx: &[u8]) -> Result<Vec<u8>, Error> {
85		let key_cache = self.key_cache.read_arc();
86
87		if let Some(cached_strongbox) = key_cache.decryptor_cache.get(&ciphertext.key_id) {
88			if cached_strongbox.is_expired(self.time.now()) {
89				tracing::debug!(key_id=%ciphertext.key_id, "Key expired");
90				Err(Error::Decryption)
91			} else {
92				cached_strongbox
93					.strong_box
94					.decrypt_ciphertext(ciphertext, ctx)
95			}
96		} else {
97			tracing::debug!(key_id=%ciphertext.key_id, "Key not found");
98			Err(Error::Decryption)
99		}
100	}
101
102	#[cfg(test)]
103	fn timewarp(&mut self, secs: i64) {
104		self.time.timewarp(secs)
105	}
106}
107
108impl StrongBox for RotatingStrongBox {
109	#[tracing::instrument(level = "debug", skip(plaintext))]
110	fn encrypt(
111		&self,
112		plaintext: impl AsRef<[u8]>,
113		ctx: impl AsRef<[u8]> + Debug,
114	) -> Result<Vec<u8>, Error> {
115		let mut key_cache = self.key_cache.write_arc();
116		key_cache
117			.current_encryptor(self.time.now())
118			.strong_box
119			.encrypt(plaintext.as_ref(), ctx.as_ref())
120	}
121
122	#[tracing::instrument(level = "debug", skip(ciphertext))]
123	fn decrypt(
124		&self,
125		ciphertext: impl AsRef<[u8]>,
126		ctx: impl AsRef<[u8]> + Debug,
127	) -> Result<Vec<u8>, Error> {
128		fn inner(
129			this: &RotatingStrongBox,
130			ciphertext: &[u8],
131			ctx: &[u8],
132		) -> Result<Vec<u8>, Error> {
133			let ciphertext = Ciphertext::try_from(ciphertext)?;
134
135			if let Ok(plaintext) = this.try_decrypt_with(&ciphertext, ctx.as_ref()) {
136				Ok(plaintext)
137			} else {
138				let mut key_cache = this.key_cache.write_arc();
139				key_cache.refresh_cache(this.time.now());
140				// Drop this drop and we'll end up with a deadlock when we call into
141				// .try_decrypt_with()
142				drop(key_cache);
143				this.try_decrypt_with(&ciphertext, ctx.as_ref())
144			}
145		}
146		inner(self, ciphertext.as_ref(), ctx.as_ref())
147	}
148}
149
/// A point in time, as a whole number of seconds.
///
/// The arithmetic operators below work directly on the wrapped second count; any
/// [`Duration`] operand is truncated to whole seconds first.
#[derive(Clone, Copy, Debug, Default, PartialEq, PartialOrd, Eq, Ord)]
#[repr(transparent)]
struct Timestamp(u64);

// Gives read access to the `u64` methods (such as `to_be_bytes`) on the wrapped value
impl std::ops::Deref for Timestamp {
	type Target = u64;

	fn deref(&self) -> &Self::Target {
		let Timestamp(secs) = self;
		secs
	}
}

// Prints as the bare number of seconds
impl std::fmt::Display for Timestamp {
	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
		write!(f, "{}", self.0)
	}
}

impl From<u64> for Timestamp {
	fn from(secs: u64) -> Self {
		Timestamp(secs)
	}
}

// Timestamp + seconds
impl std::ops::Add<u64> for Timestamp {
	type Output = Timestamp;

	fn add(self, secs: u64) -> Timestamp {
		Timestamp(self.0 + secs)
	}
}

// Timestamp + signed seconds; panics if the result does not fit in a u64
impl std::ops::Add<i64> for Timestamp {
	type Output = Timestamp;

	fn add(self, offset: i64) -> Timestamp {
		let shifted = self.0.checked_add_signed(offset).unwrap();
		Timestamp(shifted)
	}
}

// Timestamp + Timestamp
impl std::ops::Add<Timestamp> for Timestamp {
	type Output = Timestamp;

	fn add(self, other: Timestamp) -> Timestamp {
		Timestamp(self.0 + other.0)
	}
}

// Timestamp + whole seconds of a Duration
impl std::ops::Add<Duration> for Timestamp {
	type Output = Timestamp;

	fn add(self, span: Duration) -> Timestamp {
		Timestamp(self.0 + span.as_secs())
	}
}

// Timestamp - seconds
impl std::ops::Sub<u64> for Timestamp {
	type Output = Timestamp;

	fn sub(self, secs: u64) -> Timestamp {
		Timestamp(self.0 - secs)
	}
}

// Timestamp - Timestamp
impl std::ops::Sub<Timestamp> for Timestamp {
	type Output = Timestamp;

	fn sub(self, other: Timestamp) -> Timestamp {
		Timestamp(self.0 - other.0)
	}
}

// Timestamp - whole seconds of a Duration
impl std::ops::Sub<Duration> for Timestamp {
	type Output = Timestamp;

	fn sub(self, span: Duration) -> Timestamp {
		Timestamp(self.0 - span.as_secs())
	}
}

// Timestamp scaled by a small count
impl std::ops::Mul<u16> for Timestamp {
	type Output = Timestamp;

	fn mul(self, factor: u16) -> Timestamp {
		Timestamp(self.0 * u64::from(factor))
	}
}

// Seconds left over after dividing by the whole seconds of a Duration
impl std::ops::Rem<Duration> for Timestamp {
	type Output = Timestamp;

	fn rem(self, span: Duration) -> Timestamp {
		Timestamp(self.0 % span.as_secs())
	}
}
245
// The shared state behind a `RotatingStrongBox`: the root keys, the rotation settings, and
// the per-period keys derived from them so far.
#[derive(Debug)]
struct KeyCache {
	// Length of a single key period
	lifespan: Duration,
	// How many periods before the current one get decryption keys derived for them
	backtrack: u16,

	// Root key from which each period's encryption key is derived
	root_encryption_key: Key,
	// Root keys from which each period's decryption keys are derived
	root_decryption_keys: Vec<Key>,

	// Encryptor most recently derived by `current_encryptor()`, along with its expiry
	current_encryptor: CachedStrongBox,

	// Derived decryptors, looked up by the key ID carried in each ciphertext
	decryptor_cache: BTreeMap<KeyId, CachedStrongBox>,
	// This is just a way of keeping an ordered set of the
	// cached key IDs and their expiries
	decryptor_validities: BTreeSet<(Timestamp, KeyId)>,
	// Short circuit flag so we don't have to always be deriving keys; a refresh requested for
	// a time earlier than this returns without doing any work
	cache_invalid_at: Timestamp,
}
263
#[cfg(test)]
impl Default for KeyCache {
	/// Test-only blank cache: all-zero root encryption key, no decryption keys, a zero
	/// lifespan, and nothing derived yet.  Tests override the fields they care about with
	/// struct-update syntax.
	fn default() -> Self {
		let placeholder_encryptor =
			CachedStrongBox::new(Box::new([0u8; 32]).into(), Timestamp(0));

		Self {
			lifespan: Duration::ZERO,
			backtrack: 0,
			root_encryption_key: Box::new([0u8; 32]).into(),
			root_decryption_keys: Vec::new(),
			current_encryptor: placeholder_encryptor,
			decryptor_cache: BTreeMap::new(),
			decryptor_validities: BTreeSet::new(),
			cache_invalid_at: Timestamp(0),
		}
	}
}
279
280impl KeyCache {
281	fn current_encryptor(&mut self, t: Timestamp) -> &CachedStrongBox {
282		if self.current_encryptor.is_expired(t) {
283			let current_period = self.period(t);
284			self.current_encryptor = CachedStrongBox::new(
285				self.derive_key(&self.root_encryption_key, &current_period),
286				current_period.invalid_after,
287			);
288		}
289
290		&self.current_encryptor
291	}
292
293	#[tracing::instrument(level = "trace", skip(self))]
294	fn derive_key(&self, root_key: &Key, period: &Period) -> Key {
295		let mut context = b"rotation::".to_vec();
296		context.extend_from_slice(&period.to_bytes());
297
298		kdf::derive_key(root_key, &context)
299	}
300
301	#[tracing::instrument(level = "trace", skip(self))]
302	fn refresh_cache(&mut self, as_at: Timestamp) {
303		if as_at < self.cache_invalid_at {
304			return;
305		}
306
307		let mut current_period = self.period(as_at);
308
309		// By clearing out all expired keys first, we save memory *and* reduce the chances of the
310		// birthday paradox ruining our day
311		self.purge_entries_before(current_period.first_valid_at);
312
313		for b in 0..=(self.backtrack) {
314			// No point generating keys we've already generated
315			let trial_key_id = crate::key_id(
316				&self.derive_key(
317					self.root_decryption_keys
318						.first()
319						.expect("caller should have verified we have decryption keys!"),
320					&current_period,
321				),
322			);
323
324			if self.decryptor_cache.contains_key(&trial_key_id) {
325				// Since we always start from the "latest" time when adding new keys
326				// to the cache, if a key from *this* period is present, then all keys from
327				// *previous* periods must be present too
328				tracing::debug!("Cache refresh complete due to finding previous generation key");
329				return;
330			}
331
332			let invalid_after = current_period.invalid_after
333				+ (self.backtrack - b) as u64 * self.lifespan.as_secs();
334
335			// Mass key-creation time!
336			for key in &self.root_decryption_keys {
337				let key = self.derive_key(key, &current_period);
338				let key_id = crate::key_id(&key);
339				let strongbox = CachedStrongBox::new(key, invalid_after);
340
341				tracing::debug!(%key_id, %invalid_after,
342					"Adding key to cache",
343				);
344
345				self.decryptor_cache.insert(key_id, strongbox);
346				self.decryptor_validities.insert((invalid_after, key_id));
347			}
348
349			if let Some(previous_period) = current_period.previous() {
350				current_period = previous_period;
351			} else {
352				// We have reached the dawn of time... how the fuck did that happen?
353				tracing::debug!("Epoch reached");
354				return;
355			}
356		}
357
358		tracing::debug!("Cache fully populated");
359	}
360
361	#[tracing::instrument(level = "trace", skip(self))]
362	fn purge_entries_before(&mut self, t: Timestamp) {
363		loop {
364			let oldest_entry = self.oldest_cached_decryptor();
365
366			if let Some((expiry, key_id)) = oldest_entry {
367				if expiry < t {
368					tracing::debug!(%key_id, "Removing expired key");
369					self.decryptor_validities.remove(&(expiry, key_id));
370					self.decryptor_cache.remove(&key_id);
371				} else {
372					// Oldest entry is still valid, all done
373					return;
374				}
375			} else {
376				// Cache is empty, let's go home
377				return;
378			}
379		}
380	}
381
382	#[tracing::instrument(level = "trace", skip(self))]
383	fn oldest_cached_decryptor(&self) -> Option<(Timestamp, KeyId)> {
384		self.decryptor_validities.first().copied()
385	}
386
387	#[tracing::instrument(level = "trace", skip(self))]
388	fn period(&self, at: Timestamp) -> Period {
389		let first_valid_at = at - (at % self.lifespan);
390		let invalid_after = first_valid_at + self.lifespan - 1;
391
392		Period {
393			first_valid_at,
394			invalid_after,
395		}
396	}
397}
398
// Represents a "chunk" of time during which a single temporal key is valid.
// Both ends are inclusive: the chunk covers `first_valid_at` up to and including
// `invalid_after`.
#[derive(Clone, Debug, PartialEq)]
struct Period {
	// First second that belongs to this period
	first_valid_at: Timestamp,
	// Last second that belongs to this period
	invalid_after: Timestamp,
}
405
406impl Period {
407	#[tracing::instrument(level = "trace", skip(self))]
408	fn to_bytes(&self) -> Vec<u8> {
409		let mut bytes = vec![];
410
411		bytes.extend_from_slice(&self.first_valid_at.to_be_bytes());
412		bytes.extend_from_slice(&self.invalid_after.to_be_bytes());
413
414		bytes
415	}
416
417	#[tracing::instrument(level = "trace")]
418	fn previous(&self) -> Option<Period> {
419		self.back_by(1)
420	}
421
422	#[tracing::instrument(level = "trace")]
423	fn back_by(&self, n: u16) -> Option<Period> {
424		let d = (self.invalid_after - self.first_valid_at + 1u64) * n;
425
426		if self.first_valid_at < d {
427			// Can't go back before the epoch!
428			None
429		} else {
430			Some(Period {
431				first_valid_at: self.first_valid_at - d,
432				invalid_after: self.invalid_after - d,
433			})
434		}
435	}
436}
437
// A single-key strong box paired with the last moment at which it may be used.
#[derive(Debug)]
struct CachedStrongBox {
	// Last timestamp at which this box still counts as valid (see `is_expired`)
	invalid_after: Timestamp,
	// Box built from the one derived key, used for both encryption and decryption
	strong_box: StaticStrongBox,
}
443
444impl CachedStrongBox {
445	#[tracing::instrument(level = "trace", name = "CachedStrongBox::new")]
446	fn new(key: Key, invalid_after: Timestamp) -> Self {
447		Self {
448			invalid_after,
449			strong_box: StaticStrongBox::new(key.clone(), [key]),
450		}
451	}
452
453	#[tracing::instrument(level = "trace")]
454	fn is_expired(&self, now: Timestamp) -> bool {
455		now > self.invalid_after
456	}
457}
458
459#[cfg(not(test))]
460mod real_clock {
461	use super::Timestamp;
462	use std::time::{SystemTime, UNIX_EPOCH};
463
464	#[derive(Clone, Debug, Default)]
465	pub(super) struct Clock;
466
467	impl Clock {
468		#[tracing::instrument(level = "trace")]
469		pub(super) fn now(&self) -> Timestamp {
470			Timestamp(
471				SystemTime::now()
472					.duration_since(UNIX_EPOCH)
473					.unwrap()
474					.as_secs(),
475			)
476		}
477	}
478}
479
#[cfg(test)]
mod test_clock {
	use super::Timestamp;
	use std::sync::Arc;

	/// Clock for test builds: starts at the real time of creation, then only moves when
	/// `timewarp` is called.
	#[derive(Clone, Debug)]
	pub(super) struct Clock(Arc<Timestamp>);

	impl Default for Clock {
		fn default() -> Self {
			use std::time::{SystemTime, UNIX_EPOCH};

			// Get our initial time from the real world, but then freeze it
			let since_epoch = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();

			Self(Arc::new(Timestamp::from(since_epoch.as_secs())))
		}
	}

	impl Clock {
		/// The frozen time.
		#[tracing::instrument(level = "trace")]
		pub(super) fn now(&self) -> Timestamp {
			let Clock(frozen) = self;
			**frozen
		}

		/// Moves the frozen time by `secs` seconds (negative values go backwards).
		///
		/// Panics if this clock's time is currently shared with a clone.
		#[tracing::instrument(level = "trace")]
		pub(super) fn timewarp(&mut self, secs: i64) {
			let Some(frozen) = Arc::<Timestamp>::get_mut(&mut self.0) else {
				panic!("Time has no meaning");
			};

			*frozen = *frozen + secs;
		}
	}
}
518
519#[cfg(not(test))]
520use real_clock::Clock;
521#[cfg(test)]
522use test_clock::Clock;
523
#[cfg(test)]
mod tests {
	use super::*;
	use crate::generate_key;
	use std::sync::Once;
	use tracing_subscriber::{layer::SubscriberExt as _, registry::Registry};

	static INIT: Once = Once::new();

	// Installs a global tracing subscriber that writes to the test output, exactly once
	// no matter how many tests call it
	fn init() {
		INIT.call_once(|| {
			let layer = tracing_tree::HierarchicalLayer::default()
				.with_writer(tracing_subscriber::fmt::TestWriter::new())
				.with_indent_lines(true)
				.with_indent_amount(2)
				.with_targets(true);

			let sub = Registry::default().with(layer);
			tracing::subscriber::set_global_default(sub).unwrap();
		});
	}

	// Timestamps map onto lifespan-aligned periods with inclusive ends
	#[test]
	fn period_calculation() {
		init();
		let kc = KeyCache {
			lifespan: Duration::from_secs(60),
			backtrack: 0,
			..KeyCache::default()
		};

		assert_eq!(
			Period {
				first_valid_at: 0.into(),
				invalid_after: 59.into(),
			},
			kc.period(0.into())
		);
		assert_eq!(
			Period {
				first_valid_at: 0.into(),
				invalid_after: 59.into(),
			},
			kc.period(30.into())
		);
		assert_eq!(
			Period {
				first_valid_at: 0.into(),
				invalid_after: 59.into(),
			},
			kc.period(59.into())
		);
		assert_eq!(
			Period {
				first_valid_at: 60.into(),
				invalid_after: 119.into(),
			},
			kc.period(60.into())
		);
		assert_eq!(
			Period {
				first_valid_at: 1234567860.into(),
				invalid_after: 1234567919.into()
			},
			kc.period(1234567890.into())
		);
	}

	// Stepping back one period works, except from the very first period
	#[test]
	fn previous_period() {
		init();
		let kc = KeyCache {
			lifespan: Duration::from_secs(60),
			backtrack: 0,
			..KeyCache::default()
		};

		assert_eq!(None, kc.period(59.into()).previous());
		assert_eq!(
			Some(Period {
				first_valid_at: 0.into(),
				invalid_after: 59.into(),
			}),
			kc.period(60.into()).previous()
		);
	}

	#[test]
	fn simple_round_trip() {
		init();
		let key = generate_key();
		let rsb = RotatingStrongBox::new(key.clone(), vec![key], Duration::new(60, 0), 0);

		let ciphertext = rsb.encrypt(b"hello, world!", b"test").unwrap();

		assert_eq!(
			b"hello, world!".to_vec(),
			rsb.decrypt(&ciphertext, b"test")
				.expect("decryption failed")
		);
	}

	// A ciphertext must not decrypt under a different context
	#[test]
	fn context_matters() {
		init();
		let key = generate_key();
		let rsb = RotatingStrongBox::new(key.clone(), vec![key], Duration::new(60, 0), 0);

		let ciphertext = rsb.encrypt(b"hello, world!", b"context").unwrap();

		let result = rsb.decrypt(&ciphertext, b"a different context");
		assert!(matches!(result, Err(Error::Decryption)));
	}

	// A box with a new root encryption key can still decrypt data encrypted under an old
	// root key, as long as that old key is listed as a decryption key
	#[test]
	fn static_time_old_key() {
		init();
		let old_key = generate_key();
		let old_rsb = RotatingStrongBox::new(
			old_key.clone(),
			Vec::<Key>::new(),
			Duration::new(86400, 0),
			0,
		);

		let ciphertext = old_rsb.encrypt(b"hello, world!", b"test").unwrap();

		let new_key = generate_key();

		let rsb = RotatingStrongBox::new(new_key, vec![old_key], Duration::new(86400, 0), 0);

		assert_eq!(
			b"hello, world!".to_vec(),
			rsb.decrypt(&ciphertext, b"test")
				.expect("decryption failed")
		);
	}

	#[test]
	fn no_backtracking_allowed() {
		init();

		let key = generate_key();

		let mut rsb = RotatingStrongBox::new(key.clone(), vec![key], Duration::new(60, 0), 0);

		let plaintext = b"tasty, tasty plaintext";
		let ciphertext = rsb.encrypt(&plaintext, b"test").unwrap();

		// Should be able to decrypt what we just encrypted
		tracing::info!("NOW");
		assert_eq!(
			plaintext.to_vec(),
			rsb.decrypt(&ciphertext, b"test")
				.expect("decryption failed")
		);

		// With no backtracking, one period later we can no longer decrypt it
		tracing::info!("NOW+1");
		rsb.timewarp(60);
		let result = rsb.decrypt(&ciphertext, b"test");
		assert!(matches!(result, Err::<Vec<u8>, Error>(Error::Decryption)));
	}

	#[test]
	fn the_passing_of_time() {
		init();

		let key = generate_key();

		let mut rsb = RotatingStrongBox::new(key.clone(), vec![key], Duration::new(60, 0), 3);

		let plaintext = b"some sort of delicious plaintext";
		let ciphertext = rsb.encrypt(plaintext, b"test").unwrap();

		// Should be able to decrypt what we just encrypted
		tracing::info!("NOW");
		assert_eq!(
			plaintext.to_vec(),
			rsb.decrypt(&ciphertext, b"test")
				.expect("decryption failed")
		);

		// Now let's move into the fuuuutuuuuuuure

		// Can still decrypt what was encrypted one time period ago
		tracing::info!("NOW+1");
		rsb.timewarp(60);
		assert_eq!(
			plaintext.to_vec(),
			rsb.decrypt(&ciphertext, b"test")
				.expect("decryption failed")
		);

		// Two time periods...
		tracing::info!("NOW+2");
		rsb.timewarp(60);
		assert_eq!(
			plaintext.to_vec(),
			rsb.decrypt(&ciphertext, b"test")
				.expect("decryption failed")
		);

		// Even three time periods!
		tracing::info!("NOW+3");
		rsb.timewarp(60);
		assert_eq!(
			plaintext.to_vec(),
			rsb.decrypt(&ciphertext, b"test")
				.expect("decryption failed")
		);

		// But not four time periods... that's *right* out
		tracing::info!("NOW+4");
		rsb.timewarp(60);
		let result = rsb.decrypt(&ciphertext, b"test");
		assert!(matches!(result, Err::<Vec<u8>, Error>(Error::Decryption)));
	}
}