// rustls_channel_resolver/lib.rs
use std::{
cell::{OnceCell, RefCell},
sync::{
atomic::{AtomicU64, Ordering},
Arc, RwLock,
},
};
use rand::{thread_rng, Rng};
use rustls::sign::CertifiedKey;
/// Creates a certificate resolver together with the sender that updates it.
///
/// The resolver keeps `SHARDS` copies of the current key, each initialised to
/// `initial`. The returned [`ChannelSender`] shares those same shards, so
/// calling [`ChannelSender::update`] changes what the resolver hands out.
pub fn channel<const SHARDS: usize>(
    initial: CertifiedKey,
) -> (ChannelSender, Arc<ChannelResolver<SHARDS>>) {
    let resolver = Arc::new(ChannelResolver::<SHARDS>::new(initial));
    // Unsizing coercion: the sender only needs a slice of shards, not the array length.
    let erased: Arc<ErasedChannelResolver> = resolver.clone();
    let sender = ChannelSender { inner: erased };
    (sender, resolver)
}
/// Sending half of a [`channel`]: replaces the key served by the paired resolver.
///
/// Cloning only bumps an `Arc` refcount; every clone updates the same resolver.
#[derive(Clone)]
pub struct ChannelSender {
    inner: Arc<ErasedChannelResolver>,
}

impl std::fmt::Debug for ChannelSender {
    /// Shows only the shard count, in the same shape as `ChannelResolver`'s
    /// `Debug` output; the keys themselves are never printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ChannelSender")
            .field("locks", &format!("[Lock; {}]", self.inner.locks.len()))
            .finish()
    }
}
mod sealed {
    //! Storage types behind the public `ChannelResolver` alias.
    //!
    //! This module is private, so outside code can reach these types only
    //! through the crate-root aliases, and every field is `pub(super)`.

    use std::sync::atomic::AtomicU64;
    use std::sync::{Arc, RwLock};

    use rustls::sign::CertifiedKey;

    /// Resolver storage, generic over the shard container `L`.
    ///
    /// The crate instantiates `L` either as a fixed-size array of [`Shard`]s
    /// or, for the type-erased view, as an unsized slice of them.
    pub struct ChannelResolverInner<L: ?Sized> {
        pub(super) locks: L,
    }

    /// One independently locked copy of the current key.
    pub struct Shard {
        /// Tag identifying which version of the key `lock` currently holds.
        pub(super) generation: AtomicU64,
        /// The key itself.
        pub(super) lock: RwLock<Arc<CertifiedKey>>,
    }
}
/// Shard-count-erased view of [`ChannelResolver`], held by [`ChannelSender`].
type ErasedChannelResolver = sealed::ChannelResolverInner<[sealed::Shard]>;
/// Certificate resolver holding `SHARDS` independently locked copies of the
/// current key. Built by [`channel`].
pub type ChannelResolver<const SHARDS: usize> =
    sealed::ChannelResolverInner<[sealed::Shard; SHARDS]>;
thread_local! {
    // Per-thread cache of the last key this thread read from a shard, tagged
    // with the shard generation it was read at. `Shard::read` returns the
    // cached key without taking the shard's RwLock while the generation still
    // matches. Because this one cell serves every shard of every resolver on
    // the thread, a generation number must never identify two different keys,
    // or a cache hit could return the wrong key.
    static LOCAL_KEY: OnceCell<RefCell<(u64, Arc<CertifiedKey>)>> = OnceCell::new();
}
/// Hands out process-wide unique generation numbers.
///
/// `LOCAL_KEY` is a single cache shared by every shard of every resolver on a
/// thread, so a generation must identify exactly one key process-wide.
/// Per-shard counters starting at 0 would collide between resolvers and let
/// one resolver serve another resolver's certificate.
fn next_generation() -> u64 {
    static NEXT_GENERATION: AtomicU64 = AtomicU64::new(0);
    // Only uniqueness is required; an atomic RMW provides it even with `Relaxed`.
    NEXT_GENERATION.fetch_add(1, Ordering::Relaxed)
}

impl sealed::Shard {
    /// Creates a shard holding `key`, tagged with `generation`.
    fn new(key: Arc<CertifiedKey>, generation: u64) -> Self {
        Self {
            generation: AtomicU64::new(generation),
            lock: RwLock::new(key),
        }
    }

    /// Replaces the stored key and then publishes its `generation`.
    ///
    /// The key is written before the generation is stored (`Release`). A reader
    /// that observes the new generation (`Acquire`) therefore also reads the new key.
    fn update(&self, key: Arc<CertifiedKey>, generation: u64) {
        // The lock only ever holds a complete `Arc`, so a poisoned lock still
        // contains a valid value; recover it rather than panic forever after.
        *self
            .lock
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner) = key;
        self.generation.store(generation, Ordering::Release);
    }

    /// Returns the shard's current key.
    ///
    /// The key comes from the thread-local cache when that cache was filled
    /// at this shard's current generation. Otherwise it is read from the lock
    /// and the cache is refilled.
    fn read(&self) -> Arc<CertifiedKey> {
        // Load the generation *before* reading the key. If an update races with
        // us, we may cache the newer key under the older generation. That only
        // lets not-yet-updated shards serve the new key early. The opposite
        // order could pin the old key under the new generation until the next
        // update.
        let generation = self.generation.load(Ordering::Acquire);
        LOCAL_KEY.with(|cell| {
            let slot = cell.get_or_init(|| RefCell::new((generation, self.load_key())));
            let mut cached = slot.borrow_mut();
            if cached.0 != generation {
                *cached = (generation, self.load_key());
            }
            Arc::clone(&cached.1)
        })
    }

    /// Reads the key straight from the lock, bypassing the thread-local cache.
    fn load_key(&self) -> Arc<CertifiedKey> {
        // See `update`: the value behind a poisoned lock is still a whole `Arc`.
        Arc::clone(
            &self
                .lock
                .read()
                .unwrap_or_else(std::sync::PoisonError::into_inner),
        )
    }
}

impl ChannelSender {
    /// Replaces the key served by the paired resolver.
    ///
    /// The key is stored in every shard under one freshly allocated generation.
    /// Each thread therefore re-reads a lock once and then serves the new key
    /// from its cache. Shards are updated one after another, so a concurrent
    /// read may still return the previous key until this call returns.
    pub fn update(&self, key: CertifiedKey) {
        // One shared allocation instead of a deep `CertifiedKey` clone per shard.
        let key = Arc::new(key);
        let generation = next_generation();
        for shard in self.inner.locks.iter() {
            shard.update(Arc::clone(&key), generation);
        }
    }
}

impl<const SHARDS: usize> ChannelResolver<SHARDS> {
    /// Creates a resolver whose `SHARDS` shards all hold `key` under a single
    /// fresh generation.
    ///
    /// # Panics
    ///
    /// Panics if `SHARDS` is zero. Such a resolver has nothing to serve, and
    /// `read` would otherwise panic on the empty shard range at the first handshake.
    fn new(key: CertifiedKey) -> Self {
        assert!(SHARDS > 0, "ChannelResolver needs at least one shard");
        let key = Arc::new(key);
        let generation = next_generation();
        Self {
            locks: std::array::from_fn(|_| sealed::Shard::new(Arc::clone(&key), generation)),
        }
    }

    /// Returns the current key, read through a randomly chosen shard.
    #[doc(hidden)]
    pub fn read(&self) -> Arc<CertifiedKey> {
        self.locks[thread_rng().gen_range(0..SHARDS)].read()
    }
}
impl<const SHARDS: usize> std::fmt::Debug for ChannelResolver<SHARDS> {
    /// Reports only the shard count; the stored keys are never printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let summary = format!("[Lock; {}]", SHARDS);
        let mut builder = f.debug_struct("ChannelResolver");
        builder.field("locks", &summary);
        builder.finish()
    }
}
impl<const SHARDS: usize> rustls::server::ResolvesServerCert for ChannelResolver<SHARDS> {
    /// Serves the current key for every handshake.
    ///
    /// The client hello is ignored, and the result is never `None`.
    fn resolve(&self, _client_hello: rustls::server::ClientHello) -> Option<Arc<CertifiedKey>> {
        let current = self.read();
        Some(current)
    }
}