rustls_channel_resolver/
lib.rs

1use std::{
2    cell::{OnceCell, RefCell},
3    sync::{
4        atomic::{AtomicU64, Ordering},
5        Arc, RwLock,
6    },
7};
8
9use rustls::sign::CertifiedKey;
10
11/// Create a new resolver channel. The Sender can update the key, and the Receiver can be
12/// registered with a rustls server
13///
14/// the SHARDS const generic controlls the size of the channel. A larger channel will reduce
15/// contention, leading to faster reads but more writes per call to `update`.
16pub fn channel<const SHARDS: usize>(
17    initial: CertifiedKey,
18) -> (ChannelSender, Arc<ChannelResolver<SHARDS>>) {
19    let resolver = Arc::new(ChannelResolver::<SHARDS>::new(initial));
20
21    let cloned_resolver = Arc::clone(&resolver);
22    let inner = cloned_resolver as Arc<ErasedChannelResolver>;
23
24    (ChannelSender { inner }, resolver)
25}
26
/// The Send half of the channel. This is used for updating the server with new keys.
///
/// The sender is `Clone`; every clone holds an [`Arc`] to the same underlying shards, so an
/// update made through any clone is seen by the paired resolver.
#[derive(Clone)]
pub struct ChannelSender {
    // Handle to the resolver's shards with the shard count erased from the type.
    inner: Arc<ErasedChannelResolver>,
}
32
33mod sealed {
34    use std::sync::{atomic::AtomicU64, Arc, RwLock};
35
36    use rustls::sign::CertifiedKey;
37
38    pub struct ChannelResolverInner<L: ?Sized> {
39        pub(super) locks: L,
40    }
41
42    pub struct Shard {
43        pub(super) generation: AtomicU64,
44        pub(super) lock: RwLock<Arc<CertifiedKey>>,
45    }
46}
47
/// The Receive half of the channel. This is registered with rustls to provide the server with
/// keys.
///
/// `SHARDS` is the number of independently locked copies of the key held by the resolver.
pub type ChannelResolver<const SHARDS: usize> =
    sealed::ChannelResolverInner<[sealed::Shard; SHARDS]>;
/// The same storage as [`ChannelResolver`], viewed as a slice of shards so the shard count is not
/// part of the type. This is the form held by [`ChannelSender`].
type ErasedChannelResolver = sealed::ChannelResolverInner<[sealed::Shard]>;
52
53thread_local! {
54    static LOCAL_KEY: OnceCell<RefCell<(u64, Arc<CertifiedKey>)>> = const { OnceCell::new() };
55}
56
57impl sealed::Shard {
58    fn new(key: CertifiedKey) -> Self {
59        Self {
60            generation: AtomicU64::new(0),
61            lock: RwLock::new(Arc::new(key)),
62        }
63    }
64
65    fn update(&self, key: CertifiedKey) {
66        {
67            *self.lock.write().unwrap() = Arc::new(key);
68        }
69        // update generation after lock is released. reduces lock contention with readers
70        self.generation.fetch_add(1, Ordering::AcqRel);
71    }
72
73    fn read(&self) -> Arc<CertifiedKey> {
74        let generation = self.generation.load(Ordering::Acquire);
75
76        let key = LOCAL_KEY.with(|local_key| {
77            local_key.get().and_then(|refcell| {
78                let borrowed = refcell.borrow();
79                // if TLS generation is the same, we can safely return TLS key
80                if borrowed.0 == generation {
81                    Some(Arc::clone(&borrowed.1))
82                } else {
83                    None
84                }
85            })
86        });
87
88        if let Some(key) = key {
89            key
90        } else {
91            // slow path, take a read lock and update TLS with new key
92            let key = Arc::clone(&self.lock.read().unwrap());
93
94            LOCAL_KEY.with(|local_key| {
95                let guard = local_key.get_or_init(|| RefCell::new((generation, Arc::clone(&key))));
96                if guard.borrow().0 != generation {
97                    *guard.borrow_mut() = (generation, Arc::clone(&key))
98                }
99            });
100
101            key
102        }
103    }
104}
105
106impl ChannelSender {
107    /// Update the key in the channel
108    pub fn update(&self, key: CertifiedKey) {
109        for lock in &self.inner.locks {
110            lock.update(key.clone());
111        }
112    }
113}
114
115impl<const SHARDS: usize> ChannelResolver<SHARDS> {
116    fn new(key: CertifiedKey) -> Self {
117        Self {
118            locks: [(); SHARDS].map(|()| sealed::Shard::new(key.clone())),
119        }
120    }
121
122    // exposed for benching
123    #[doc(hidden)]
124    pub fn read(&self) -> Arc<CertifiedKey> {
125        // choose random shard to reduce contention. unwrap since slice is always non-empty
126        self.locks[rand::random_range(0..SHARDS)].read()
127    }
128}
129
130impl<const SHARDS: usize> std::fmt::Debug for ChannelResolver<SHARDS> {
131    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
132        f.debug_struct("ChannelResolver")
133            .field("locks", &format!("[Lock; {SHARDS}]"))
134            .finish()
135    }
136}
137
138impl<const SHARDS: usize> rustls::server::ResolvesServerCert for ChannelResolver<SHARDS> {
139    fn resolve(&self, _: rustls::server::ClientHello) -> Option<Arc<CertifiedKey>> {
140        Some(self.read())
141    }
142}