1use parking_lot::RwLock;
2use std::{
3 collections::{BTreeMap, BTreeSet},
4 fmt::Debug,
5 sync::Arc,
6 time::Duration,
7};
8
9use super::{Ciphertext, Error, Key, KeyId, StaticStrongBox, StrongBox, kdf};
10
/// A strong box whose keys rotate over fixed-length periods.
///
/// Each period's encryption key is derived from a root encryption key, and
/// decryption keys for the current period and a configurable number of
/// earlier periods are derived from root decryption keys and cached on demand.
#[derive(Clone, Debug)]
pub struct RotatingStrongBox {
    /// Time source: the system clock in normal builds, an adjustable clock in tests.
    time: Clock,

    /// Key state shared, through the `Arc`, by every clone of this box.
    key_cache: Arc<RwLock<KeyCache>>,
}
53
// Manually asserts unwind safety. The shared `KeyCache` sits behind a
// `parking_lot` lock, which prevents the automatic impl.
// NOTE(review): presumably this relies on `parking_lot` locks not poisoning.
// Confirm that a panic part-way through a cache update cannot leave
// `decryptor_cache` and `decryptor_validities` out of step in a way that matters.
impl std::panic::UnwindSafe for RotatingStrongBox {}
56
57impl RotatingStrongBox {
58 #[tracing::instrument(level = "trace")]
59 pub(super) fn new(
60 enc_key: Key,
61 dec_keys: Vec<Key>,
62 lifespan: Duration,
63 backtrack: u16,
64 ) -> Self {
65 Self {
66 #[cfg(not(test))]
67 time: Clock,
68 #[cfg(test)]
69 time: Clock::default(),
70 key_cache: Arc::new(RwLock::new(KeyCache {
71 lifespan,
72 backtrack,
73 root_encryption_key: enc_key,
74 root_decryption_keys: dec_keys,
75 current_encryptor: CachedStrongBox::new(Box::new([0u8; 32]).into(), Timestamp(0)),
76 decryptor_cache: BTreeMap::default(),
77 decryptor_validities: BTreeSet::default(),
78 cache_invalid_at: Timestamp::default(),
79 })),
80 }
81 }
82
83 #[tracing::instrument(level = "trace", skip(self, ciphertext))]
84 fn try_decrypt_with(&self, ciphertext: &Ciphertext, ctx: &[u8]) -> Result<Vec<u8>, Error> {
85 let key_cache = self.key_cache.read_arc();
86
87 if let Some(cached_strongbox) = key_cache.decryptor_cache.get(&ciphertext.key_id) {
88 if cached_strongbox.is_expired(self.time.now()) {
89 tracing::debug!(key_id=%ciphertext.key_id, "Key expired");
90 Err(Error::Decryption)
91 } else {
92 cached_strongbox
93 .strong_box
94 .decrypt_ciphertext(ciphertext, ctx)
95 }
96 } else {
97 tracing::debug!(key_id=%ciphertext.key_id, "Key not found");
98 Err(Error::Decryption)
99 }
100 }
101
102 #[cfg(test)]
103 fn timewarp(&mut self, secs: i64) {
104 self.time.timewarp(secs)
105 }
106}
107
108impl StrongBox for RotatingStrongBox {
109 #[tracing::instrument(level = "debug", skip(plaintext))]
110 fn encrypt(
111 &self,
112 plaintext: impl AsRef<[u8]>,
113 ctx: impl AsRef<[u8]> + Debug,
114 ) -> Result<Vec<u8>, Error> {
115 let mut key_cache = self.key_cache.write_arc();
116 key_cache
117 .current_encryptor(self.time.now())
118 .strong_box
119 .encrypt(plaintext.as_ref(), ctx.as_ref())
120 }
121
122 #[tracing::instrument(level = "debug", skip(ciphertext))]
123 fn decrypt(
124 &self,
125 ciphertext: impl AsRef<[u8]>,
126 ctx: impl AsRef<[u8]> + Debug,
127 ) -> Result<Vec<u8>, Error> {
128 fn inner(
129 this: &RotatingStrongBox,
130 ciphertext: &[u8],
131 ctx: &[u8],
132 ) -> Result<Vec<u8>, Error> {
133 let ciphertext = Ciphertext::try_from(ciphertext)?;
134
135 if let Ok(plaintext) = this.try_decrypt_with(&ciphertext, ctx.as_ref()) {
136 Ok(plaintext)
137 } else {
138 let mut key_cache = this.key_cache.write_arc();
139 key_cache.refresh_cache(this.time.now());
140 drop(key_cache);
143 this.try_decrypt_with(&ciphertext, ctx.as_ref())
144 }
145 }
146 inner(self, ciphertext.as_ref(), ctx.as_ref())
147 }
148}
149
/// A point in time, in whole seconds since the Unix epoch.
///
/// All arithmetic on timestamps is checked. Overflow, underflow or a
/// zero-second divisor panics with a descriptive message in every build
/// profile. Plain `+`, `-`, `*` and `%` on `u64` would panic only in debug
/// builds and silently wrap in release builds, which could corrupt the
/// key-validity times computed from these values.
#[derive(Clone, Copy, Debug, Default, PartialEq, PartialOrd, Eq, Ord)]
#[repr(transparent)]
struct Timestamp(u64);

impl std::ops::Deref for Timestamp {
    type Target = u64;

    fn deref(&self) -> &u64 {
        &self.0
    }
}

impl std::fmt::Display for Timestamp {
    /// Formats the timestamp as its raw number of seconds.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl From<u64> for Timestamp {
    fn from(t: u64) -> Self {
        Self(t)
    }
}

impl std::ops::Add<u64> for Timestamp {
    type Output = Timestamp;

    /// Adds `t` seconds. Panics on overflow.
    fn add(self, t: u64) -> Self {
        Self(self.0.checked_add(t).expect("timestamp addition overflowed"))
    }
}

impl std::ops::Add<i64> for Timestamp {
    type Output = Timestamp;

    /// Moves the timestamp by a signed number of seconds. Panics if the
    /// result would fall outside `0..=u64::MAX`.
    fn add(self, t: i64) -> Self {
        Self(
            self.0
                .checked_add_signed(t)
                .expect("timestamp signed addition out of range"),
        )
    }
}

impl std::ops::Add<Timestamp> for Timestamp {
    type Output = Timestamp;

    /// Adds the seconds of `t`. Panics on overflow.
    fn add(self, t: Timestamp) -> Self {
        Self(self.0.checked_add(t.0).expect("timestamp addition overflowed"))
    }
}

impl std::ops::Add<Duration> for Timestamp {
    type Output = Timestamp;

    /// Adds the whole seconds of `t`, discarding any sub-second part.
    /// Panics on overflow.
    fn add(self, t: Duration) -> Self {
        Self(
            self.0
                .checked_add(t.as_secs())
                .expect("timestamp addition overflowed"),
        )
    }
}

impl std::ops::Sub<u64> for Timestamp {
    type Output = Timestamp;

    /// Subtracts `t` seconds. Panics on underflow.
    fn sub(self, t: u64) -> Self {
        Self(self.0.checked_sub(t).expect("timestamp subtraction underflowed"))
    }
}

impl std::ops::Sub<Timestamp> for Timestamp {
    type Output = Timestamp;

    /// Subtracts the seconds of `t`. Panics on underflow.
    fn sub(self, t: Timestamp) -> Self {
        Self(
            self.0
                .checked_sub(t.0)
                .expect("timestamp subtraction underflowed"),
        )
    }
}

impl std::ops::Sub<Duration> for Timestamp {
    type Output = Timestamp;

    /// Subtracts the whole seconds of `t`, discarding any sub-second part.
    /// Panics on underflow.
    fn sub(self, t: Duration) -> Self {
        Self(
            self.0
                .checked_sub(t.as_secs())
                .expect("timestamp subtraction underflowed"),
        )
    }
}

impl std::ops::Mul<u16> for Timestamp {
    type Output = Timestamp;

    /// Multiplies the number of seconds by `n`. Panics on overflow.
    fn mul(self, n: u16) -> Self {
        Self(
            self.0
                .checked_mul(u64::from(n))
                .expect("timestamp multiplication overflowed"),
        )
    }
}

impl std::ops::Rem<Duration> for Timestamp {
    type Output = Timestamp;

    /// Returns the remainder modulo the whole seconds of `t`.
    /// Panics if `t` is shorter than one second, since that divisor is zero.
    fn rem(self, t: Duration) -> Self {
        Self(
            self.0
                .checked_rem(t.as_secs())
                .expect("timestamp remainder by a duration shorter than one second"),
        )
    }
}
245
/// Key state shared by all clones of a `RotatingStrongBox`.
#[derive(Debug)]
struct KeyCache {
    /// Length of one key-rotation period. Timestamp arithmetic uses it in whole seconds.
    lifespan: Duration,
    /// Number of periods before the current one that a cache refresh derives
    /// decryption keys for.
    backtrack: u16,

    /// Root key from which each period's encryption key is derived.
    root_encryption_key: Key,
    /// Root keys from which each period's decryption keys are derived.
    root_decryption_keys: Vec<Key>,

    /// Encryptor for the most recently requested period. It is replaced once expired.
    current_encryptor: CachedStrongBox,

    /// Period decryptors, keyed by the ID of their derived key.
    decryptor_cache: BTreeMap<KeyId, CachedStrongBox>,
    /// `(expiry, key ID)` pairs that mirror `decryptor_cache`. Ordering by
    /// expiry first puts the soonest-expiring entry first, for purging.
    decryptor_validities: BTreeSet<(Timestamp, KeyId)>,
    /// Timestamp that `refresh_cache` consults to decide whether it can skip a refresh.
    cache_invalid_at: Timestamp,
}
263
#[cfg(test)]
impl Default for KeyCache {
    /// Returns a zero-lifespan, zero-backtrack cache with all-zero placeholder
    /// keys. Tests override the fields they care about with struct-update syntax.
    fn default() -> Self {
        let zero_key = || -> Key { Box::new([0u8; 32]).into() };

        Self {
            lifespan: Duration::ZERO,
            backtrack: 0,
            root_encryption_key: zero_key(),
            root_decryption_keys: Vec::new(),
            current_encryptor: CachedStrongBox::new(zero_key(), Timestamp::default()),
            decryptor_cache: BTreeMap::new(),
            decryptor_validities: BTreeSet::new(),
            cache_invalid_at: Timestamp::default(),
        }
    }
}
279
280impl KeyCache {
281 fn current_encryptor(&mut self, t: Timestamp) -> &CachedStrongBox {
282 if self.current_encryptor.is_expired(t) {
283 let current_period = self.period(t);
284 self.current_encryptor = CachedStrongBox::new(
285 self.derive_key(&self.root_encryption_key, ¤t_period),
286 current_period.invalid_after,
287 );
288 }
289
290 &self.current_encryptor
291 }
292
293 #[tracing::instrument(level = "trace", skip(self))]
294 fn derive_key(&self, root_key: &Key, period: &Period) -> Key {
295 let mut context = b"rotation::".to_vec();
296 context.extend_from_slice(&period.to_bytes());
297
298 kdf::derive_key(root_key, &context)
299 }
300
301 #[tracing::instrument(level = "trace", skip(self))]
302 fn refresh_cache(&mut self, as_at: Timestamp) {
303 if as_at < self.cache_invalid_at {
304 return;
305 }
306
307 let mut current_period = self.period(as_at);
308
309 self.purge_entries_before(current_period.first_valid_at);
312
313 for b in 0..=(self.backtrack) {
314 let trial_key_id = crate::key_id(
316 &self.derive_key(
317 self.root_decryption_keys
318 .first()
319 .expect("caller should have verified we have decryption keys!"),
320 ¤t_period,
321 ),
322 );
323
324 if self.decryptor_cache.contains_key(&trial_key_id) {
325 tracing::debug!("Cache refresh complete due to finding previous generation key");
329 return;
330 }
331
332 let invalid_after = current_period.invalid_after
333 + (self.backtrack - b) as u64 * self.lifespan.as_secs();
334
335 for key in &self.root_decryption_keys {
337 let key = self.derive_key(key, ¤t_period);
338 let key_id = crate::key_id(&key);
339 let strongbox = CachedStrongBox::new(key, invalid_after);
340
341 tracing::debug!(%key_id, %invalid_after,
342 "Adding key to cache",
343 );
344
345 self.decryptor_cache.insert(key_id, strongbox);
346 self.decryptor_validities.insert((invalid_after, key_id));
347 }
348
349 if let Some(previous_period) = current_period.previous() {
350 current_period = previous_period;
351 } else {
352 tracing::debug!("Epoch reached");
354 return;
355 }
356 }
357
358 tracing::debug!("Cache fully populated");
359 }
360
361 #[tracing::instrument(level = "trace", skip(self))]
362 fn purge_entries_before(&mut self, t: Timestamp) {
363 loop {
364 let oldest_entry = self.oldest_cached_decryptor();
365
366 if let Some((expiry, key_id)) = oldest_entry {
367 if expiry < t {
368 tracing::debug!(%key_id, "Removing expired key");
369 self.decryptor_validities.remove(&(expiry, key_id));
370 self.decryptor_cache.remove(&key_id);
371 } else {
372 return;
374 }
375 } else {
376 return;
378 }
379 }
380 }
381
382 #[tracing::instrument(level = "trace", skip(self))]
383 fn oldest_cached_decryptor(&self) -> Option<(Timestamp, KeyId)> {
384 self.decryptor_validities.first().copied()
385 }
386
387 #[tracing::instrument(level = "trace", skip(self))]
388 fn period(&self, at: Timestamp) -> Period {
389 let first_valid_at = at - (at % self.lifespan);
390 let invalid_after = first_valid_at + self.lifespan - 1;
391
392 Period {
393 first_valid_at,
394 invalid_after,
395 }
396 }
397}
398
/// An inclusive range of whole seconds covered by one key-rotation period.
#[derive(Clone, Debug, PartialEq)]
struct Period {
    /// First second of the period. `KeyCache::period` aligns it to a multiple of the lifespan.
    first_valid_at: Timestamp,
    /// Last second of the period, inclusive.
    invalid_after: Timestamp,
}
405
406impl Period {
407 #[tracing::instrument(level = "trace", skip(self))]
408 fn to_bytes(&self) -> Vec<u8> {
409 let mut bytes = vec![];
410
411 bytes.extend_from_slice(&self.first_valid_at.to_be_bytes());
412 bytes.extend_from_slice(&self.invalid_after.to_be_bytes());
413
414 bytes
415 }
416
417 #[tracing::instrument(level = "trace")]
418 fn previous(&self) -> Option<Period> {
419 self.back_by(1)
420 }
421
422 #[tracing::instrument(level = "trace")]
423 fn back_by(&self, n: u16) -> Option<Period> {
424 let d = (self.invalid_after - self.first_valid_at + 1u64) * n;
425
426 if self.first_valid_at < d {
427 None
429 } else {
430 Some(Period {
431 first_valid_at: self.first_valid_at - d,
432 invalid_after: self.invalid_after - d,
433 })
434 }
435 }
436}
437
/// A `StaticStrongBox` paired with the last second at which it may be used.
#[derive(Debug)]
struct CachedStrongBox {
    /// Last second (inclusive) at which this box is valid; see `is_expired`.
    invalid_after: Timestamp,
    /// Box constructed from a single key (see `CachedStrongBox::new`).
    strong_box: StaticStrongBox,
}
443
444impl CachedStrongBox {
445 #[tracing::instrument(level = "trace", name = "CachedStrongBox::new")]
446 fn new(key: Key, invalid_after: Timestamp) -> Self {
447 Self {
448 invalid_after,
449 strong_box: StaticStrongBox::new(key.clone(), [key]),
450 }
451 }
452
453 #[tracing::instrument(level = "trace")]
454 fn is_expired(&self, now: Timestamp) -> bool {
455 now > self.invalid_after
456 }
457}
458
459#[cfg(not(test))]
460mod real_clock {
461 use super::Timestamp;
462 use std::time::{SystemTime, UNIX_EPOCH};
463
464 #[derive(Clone, Debug, Default)]
465 pub(super) struct Clock;
466
467 impl Clock {
468 #[tracing::instrument(level = "trace")]
469 pub(super) fn now(&self) -> Timestamp {
470 Timestamp(
471 SystemTime::now()
472 .duration_since(UNIX_EPOCH)
473 .unwrap()
474 .as_secs(),
475 )
476 }
477 }
478}
479
#[cfg(test)]
mod test_clock {
    use super::Timestamp;
    use std::sync::Arc;

    /// Test time source that holds a fixed instant, which only changes via `timewarp`.
    #[derive(Clone, Debug)]
    pub(super) struct Clock(Arc<Timestamp>);

    impl Default for Clock {
        /// Starts the clock at the current wall-clock second.
        fn default() -> Self {
            use std::time::{SystemTime, UNIX_EPOCH};

            let since_epoch = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
            Self(Arc::new(Timestamp::from(since_epoch.as_secs())))
        }
    }

    impl Clock {
        /// Returns the clock's current instant.
        #[tracing::instrument(level = "trace")]
        pub(super) fn now(&self) -> Timestamp {
            *self.0
        }

        /// Shifts the clock by `secs` seconds.
        ///
        /// # Panics
        ///
        /// Panics if the underlying `Arc` is shared with another clone of this clock.
        #[tracing::instrument(level = "trace")]
        pub(super) fn timewarp(&mut self, secs: i64) {
            match Arc::<Timestamp>::get_mut(&mut self.0) {
                Some(instant) => *instant = *instant + secs,
                None => panic!("Time has no meaning"),
            }
        }
    }
}
518
519#[cfg(not(test))]
520use real_clock::Clock;
521#[cfg(test)]
522use test_clock::Clock;
523
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generate_key;
    use std::sync::Once;
    use tracing_subscriber::{layer::SubscriberExt as _, registry::Registry};

    // Ensures the global tracing subscriber is installed at most once per test binary.
    static INIT: Once = Once::new();

    /// Installs a `tracing_tree` subscriber that writes through the test
    /// harness's captured output. Every test calls this, but only the first
    /// call installs anything.
    fn init() {
        INIT.call_once(|| {
            let layer = tracing_tree::HierarchicalLayer::default()
                .with_writer(tracing_subscriber::fmt::TestWriter::new())
                .with_indent_lines(true)
                .with_indent_amount(2)
                .with_targets(true);

            let sub = Registry::default().with(layer);
            tracing::subscriber::set_global_default(sub).unwrap();
        });
    }

    /// `KeyCache::period` returns the lifespan-aligned, inclusive window that
    /// contains the given timestamp, including at both edges of the window.
    #[test]
    fn period_calculation() {
        init();
        let kc = KeyCache {
            lifespan: Duration::from_secs(60),
            backtrack: 0,
            ..KeyCache::default()
        };

        assert_eq!(
            Period {
                first_valid_at: 0.into(),
                invalid_after: 59.into(),
            },
            kc.period(0.into())
        );
        assert_eq!(
            Period {
                first_valid_at: 0.into(),
                invalid_after: 59.into(),
            },
            kc.period(30.into())
        );
        assert_eq!(
            Period {
                first_valid_at: 0.into(),
                invalid_after: 59.into(),
            },
            kc.period(59.into())
        );
        assert_eq!(
            Period {
                first_valid_at: 60.into(),
                invalid_after: 119.into(),
            },
            kc.period(60.into())
        );
        assert_eq!(
            Period {
                first_valid_at: 1234567860.into(),
                invalid_after: 1234567919.into()
            },
            kc.period(1234567890.into())
        );
    }

    /// The first period after the epoch has no predecessor. A later period
    /// steps back by exactly one lifespan.
    #[test]
    fn previous_period() {
        init();
        let kc = KeyCache {
            lifespan: Duration::from_secs(60),
            backtrack: 0,
            ..KeyCache::default()
        };

        assert_eq!(None, kc.period(59.into()).previous());
        assert_eq!(
            Some(Period {
                first_valid_at: 0.into(),
                invalid_after: 59.into(),
            }),
            kc.period(60.into()).previous()
        );
    }

    /// Encrypting and then decrypting with the same key and context returns the plaintext.
    #[test]
    fn simple_round_trip() {
        init();
        let key = generate_key();
        let rsb = RotatingStrongBox::new(key.clone(), vec![key], Duration::new(60, 0), 0);

        let ciphertext = rsb.encrypt(b"hello, world!", b"test").unwrap();

        assert_eq!(
            b"hello, world!".to_vec(),
            rsb.decrypt(&ciphertext, b"test")
                .expect("encryption failed")
        );
    }

    /// Decryption fails with `Error::Decryption` when the context differs
    /// from the one used to encrypt.
    #[test]
    fn context_matters() {
        init();
        let key = generate_key();
        let rsb = RotatingStrongBox::new(key.clone(), vec![key], Duration::new(60, 0), 0);

        let ciphertext = rsb.encrypt(b"hello, world!", b"context").unwrap();

        let result = rsb.decrypt(&ciphertext, b"a different context");
        assert!(matches!(result, Err(Error::Decryption)));
    }

    /// A box whose root decryption keys include another box's encryption key
    /// can decrypt that box's ciphertext.
    #[test]
    fn static_time_old_key() {
        init();
        let old_key = generate_key();
        let old_rsb = RotatingStrongBox::new(
            old_key.clone(),
            Vec::<Key>::new(),
            Duration::new(86400, 0),
            0,
        );

        let ciphertext = old_rsb.encrypt(b"hello, world!", b"test").unwrap();

        let new_key = generate_key();

        let rsb = RotatingStrongBox::new(new_key, vec![old_key], Duration::new(86400, 0), 0);

        assert_eq!(
            b"hello, world!".to_vec(),
            rsb.decrypt(&ciphertext, b"test")
                .expect("decryption failed")
        );
    }

    /// With `backtrack` set to 0, a ciphertext stops decrypting once the
    /// clock moves one full lifespan forward.
    #[test]
    fn no_backtracking_allowed() {
        init();

        let key = generate_key();

        let mut rsb = RotatingStrongBox::new(key.clone(), vec![key], Duration::new(60, 0), 0);

        let plaintext = b"tasty, tasty plaintext";
        let ciphertext = rsb.encrypt(&plaintext, b"test").unwrap();

        tracing::info!("NOW");
        assert_eq!(
            plaintext.to_vec(),
            rsb.decrypt(&ciphertext, b"test")
                .expect("decryption failed")
        );

        tracing::info!("NOW+1");
        // Advancing one lifespan always lands in the next period.
        rsb.timewarp(60);
        let result = rsb.decrypt(&ciphertext, b"test");
        assert!(matches!(result, Err::<Vec<u8>, Error>(Error::Decryption)));
    }

    /// With `backtrack` set to 3, a ciphertext keeps decrypting for three
    /// further lifespans and fails after the fourth.
    #[test]
    fn the_passing_of_time() {
        init();

        let key = generate_key();

        let mut rsb = RotatingStrongBox::new(key.clone(), vec![key], Duration::new(60, 0), 3);

        let plaintext = b"some sort of delicious plaintext";
        let ciphertext = rsb.encrypt(plaintext, b"test").unwrap();

        tracing::info!("NOW");
        assert_eq!(
            plaintext.to_vec(),
            rsb.decrypt(&ciphertext, b"test")
                .expect("decryption failed")
        );

        tracing::info!("NOW+1");
        rsb.timewarp(60);
        assert_eq!(
            plaintext.to_vec(),
            rsb.decrypt(&ciphertext, b"test")
                .expect("decryption failed")
        );

        tracing::info!("NOW+2");
        rsb.timewarp(60);
        assert_eq!(
            plaintext.to_vec(),
            rsb.decrypt(&ciphertext, b"test")
                .expect("decryption failed")
        );

        tracing::info!("NOW+3");
        rsb.timewarp(60);
        assert_eq!(
            plaintext.to_vec(),
            rsb.decrypt(&ciphertext, b"test")
                .expect("decryption failed")
        );

        tracing::info!("NOW+4");
        // Four lifespans later the original period is outside the backtrack window.
        rsb.timewarp(60);
        let result = rsb.decrypt(&ciphertext, b"test");
        assert!(matches!(result, Err::<Vec<u8>, Error>(Error::Decryption)));
    }
}