rustls_channel_resolver/
lib.rs1use std::{
2 cell::{OnceCell, RefCell},
3 sync::{
4 atomic::{AtomicU64, Ordering},
5 Arc, RwLock,
6 },
7};
8
9use rustls::sign::CertifiedKey;
10
11pub fn channel<const SHARDS: usize>(
17 initial: CertifiedKey,
18) -> (ChannelSender, Arc<ChannelResolver<SHARDS>>) {
19 let resolver = Arc::new(ChannelResolver::<SHARDS>::new(initial));
20
21 let cloned_resolver = Arc::clone(&resolver);
22 let inner = cloned_resolver as Arc<ErasedChannelResolver>;
23
24 (ChannelSender { inner }, resolver)
25}
26
/// Sending half of the pair created by [`channel`].
///
/// Cloning is cheap: every clone shares the same shards through the inner
/// `Arc`.
#[derive(Clone)]
pub struct ChannelSender {
    // Length-erased view of the resolver's shards, so this type does not need
    // the `SHARDS` const parameter.
    inner: Arc<ErasedChannelResolver>,
}
32
// Private module holding the concrete storage types. The types are `pub` so
// they can appear in the public `ChannelResolver` alias, while their fields
// are only visible to the parent module.
mod sealed {
    use std::sync::{atomic::AtomicU64, Arc, RwLock};

    use rustls::sign::CertifiedKey;

    /// Container for the shards. `L` is either a fixed-size array of
    /// [`Shard`] or, when `?Sized` is used, a slice of them.
    pub struct ChannelResolverInner<L: ?Sized> {
        pub(super) locks: L,
    }

    /// One independently locked copy of the current key.
    pub struct Shard {
        // Counter bumped after each key replacement; readers compare it with
        // their thread-local cache entry to decide whether the lock must be
        // taken.
        pub(super) generation: AtomicU64,
        // The key currently stored in this shard.
        pub(super) lock: RwLock<Arc<CertifiedKey>>,
    }
}
47
/// Certificate resolver backed by `SHARDS` independently locked copies of the
/// current key. Obtain one through [`channel`].
pub type ChannelResolver<const SHARDS: usize> =
    sealed::ChannelResolverInner<[sealed::Shard; SHARDS]>;
// Same storage with the shard count erased (slice instead of array); this is
// the form held by `ChannelSender`.
type ErasedChannelResolver = sealed::ChannelResolverInner<[sealed::Shard]>;
52
// Per-thread cache of the most recently read key together with the shard
// generation it was read at. It is created lazily on the first read. Entries
// are matched by generation number only, so a generation value must never
// stand for two different keys.
thread_local! {
    static LOCAL_KEY: OnceCell<RefCell<(u64, Arc<CertifiedKey>)>> = const { OnceCell::new() };
}
56
57impl sealed::Shard {
58 fn new(key: CertifiedKey) -> Self {
59 Self {
60 generation: AtomicU64::new(0),
61 lock: RwLock::new(Arc::new(key)),
62 }
63 }
64
65 fn update(&self, key: CertifiedKey) {
66 {
67 *self.lock.write().unwrap() = Arc::new(key);
68 }
69 self.generation.fetch_add(1, Ordering::AcqRel);
71 }
72
73 fn read(&self) -> Arc<CertifiedKey> {
74 let generation = self.generation.load(Ordering::Acquire);
75
76 let key = LOCAL_KEY.with(|local_key| {
77 local_key.get().and_then(|refcell| {
78 let borrowed = refcell.borrow();
79 if borrowed.0 == generation {
81 Some(Arc::clone(&borrowed.1))
82 } else {
83 None
84 }
85 })
86 });
87
88 if let Some(key) = key {
89 key
90 } else {
91 let key = Arc::clone(&self.lock.read().unwrap());
93
94 LOCAL_KEY.with(|local_key| {
95 let guard = local_key.get_or_init(|| RefCell::new((generation, Arc::clone(&key))));
96 if guard.borrow().0 != generation {
97 *guard.borrow_mut() = (generation, Arc::clone(&key))
98 }
99 });
100
101 key
102 }
103 }
104}
105
106impl ChannelSender {
107 pub fn update(&self, key: CertifiedKey) {
109 for lock in &self.inner.locks {
110 lock.update(key.clone());
111 }
112 }
113}
114
115impl<const SHARDS: usize> ChannelResolver<SHARDS> {
116 fn new(key: CertifiedKey) -> Self {
117 Self {
118 locks: [(); SHARDS].map(|()| sealed::Shard::new(key.clone())),
119 }
120 }
121
122 #[doc(hidden)]
124 pub fn read(&self) -> Arc<CertifiedKey> {
125 self.locks[rand::random_range(0..SHARDS)].read()
127 }
128}
129
130impl<const SHARDS: usize> std::fmt::Debug for ChannelResolver<SHARDS> {
131 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
132 f.debug_struct("ChannelResolver")
133 .field("locks", &format!("[Lock; {SHARDS}]"))
134 .finish()
135 }
136}
137
138impl<const SHARDS: usize> rustls::server::ResolvesServerCert for ChannelResolver<SHARDS> {
139 fn resolve(&self, _: rustls::server::ClientHello) -> Option<Arc<CertifiedKey>> {
140 Some(self.read())
141 }
142}