rustls_channel_resolver/
lib.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
use std::{
    cell::{OnceCell, RefCell},
    sync::{
        atomic::{AtomicU64, Ordering},
        Arc, RwLock,
    },
};

use rand::{thread_rng, Rng};
use rustls::sign::CertifiedKey;

/// Creates a connected sender/resolver pair seeded with the `initial` key.
///
/// The returned [`ChannelSender`] can replace the key, and the returned resolver can be
/// registered with a rustls server.
///
/// The `SHARDS` const generic controls the size of the channel. A larger channel will reduce
/// contention, leading to faster reads but more writes per call to `update`.
pub fn channel<const SHARDS: usize>(
    initial: CertifiedKey,
) -> (ChannelSender, Arc<ChannelResolver<SHARDS>>) {
    let resolver = Arc::new(ChannelResolver::<SHARDS>::new(initial));

    // Second handle to the same allocation. It is bound to a local first so that its type is
    // inferred from `resolver`; the struct field below then erases the shard count.
    let shared = Arc::clone(&resolver);
    let sender = ChannelSender { inner: shared };

    (sender, resolver)
}

/// The Send half of the channel. This is used for updating the server with new keys.
///
/// Clones share the same underlying shards, so an update sent through any clone is seen by
/// the resolver the sender was created with.
#[derive(Clone)]
pub struct ChannelSender {
    // Handle to the resolver's shards with the compile-time shard count erased, so this type
    // does not need a `SHARDS` parameter.
    inner: Arc<ErasedChannelResolver>,
}

// Private home for the types behind the public `ChannelResolver` alias. The types are `pub`
// so they can appear in public signatures, while their fields stay `pub(super)`.
mod sealed {
    use std::sync::{atomic::AtomicU64, Arc, RwLock};

    use rustls::sign::CertifiedKey;

    /// Container for the shards. `L` is either a fixed-size array of shards or, when the
    /// shard count has been erased, a slice of shards.
    pub struct ChannelResolverInner<L: ?Sized> {
        pub(super) locks: L,
    }

    /// One copy of the current key together with a counter that is bumped on every update.
    pub struct Shard {
        // Incremented after each key replacement; readers compare it against the generation
        // stored in their thread-local cache.
        pub(super) generation: AtomicU64,
        // The key itself; only locked for reading when the thread-local cache misses.
        pub(super) lock: RwLock<Arc<CertifiedKey>>,
    }
}

/// The Receive half of the channel. This is registered with rustls to provide the server with
/// keys.
pub type ChannelResolver<const SHARDS: usize> =
    sealed::ChannelResolverInner<[sealed::Shard; SHARDS]>;
// Same layout with the shard count erased (slice instead of array); this is the form held by
// `ChannelSender`.
type ErasedChannelResolver = sealed::ChannelResolverInner<[sealed::Shard]>;

// Per-thread cache of the last key handed out, stored as `(generation, key)`. It is filled in
// lazily on the first read. There is a single cache per thread, shared by every shard that
// thread reads from, and a cached entry is identified by its generation number alone.
thread_local! {
    static LOCAL_KEY: OnceCell<RefCell<(u64, Arc<CertifiedKey>)>> = OnceCell::new();
}

impl sealed::Shard {
    /// Builds a shard that holds `key` and starts counting generations at zero.
    fn new(key: CertifiedKey) -> Self {
        let lock = RwLock::new(Arc::new(key));
        let generation = AtomicU64::new(0);

        Self { generation, lock }
    }

    /// Replaces the stored key and then advances the generation, which makes readers whose
    /// thread-local copy carries the previous generation fall back to the lock.
    fn update(&self, key: CertifiedKey) {
        // Allocate before locking so the write lock is held only for the swap itself.
        let replacement = Arc::new(key);
        {
            let mut slot = self.lock.write().unwrap();
            *slot = replacement;
        }
        // The generation is bumped only once the write guard is gone, which reduces lock
        // contention with readers.
        self.generation.fetch_add(1, Ordering::AcqRel);
    }

    /// Returns the current key, preferring the calling thread's cached copy.
    ///
    /// The cached copy is used whenever its generation equals this shard's generation;
    /// otherwise the key is cloned under a read lock and the cache is refreshed.
    fn read(&self) -> Arc<CertifiedKey> {
        let current = self.generation.load(Ordering::Acquire);

        // Fast path: the thread-local entry exists and carries the same generation.
        let cached = LOCAL_KEY.with(|cell| {
            let entry = cell.get()?.borrow();
            (entry.0 == current).then(|| Arc::clone(&entry.1))
        });

        if let Some(hit) = cached {
            return hit;
        }

        // Slow path: clone the key under a read lock (released at the end of the statement).
        let fresh = Arc::clone(&self.lock.read().unwrap());

        LOCAL_KEY.with(|cell| {
            let slot = cell.get_or_init(|| RefCell::new((current, Arc::clone(&fresh))));
            // If the cell was initialised just now it already matches; only an older entry
            // needs to be overwritten.
            let stale = slot.borrow().0 != current;
            if stale {
                *slot.borrow_mut() = (current, Arc::clone(&fresh));
            }
        });

        fresh
    }
}

impl ChannelSender {
    /// Update the key in the channel.
    ///
    /// Shards are updated one after another, each receiving its own clone of `key`, so
    /// concurrent readers may briefly observe the old key on shards not yet reached.
    pub fn update(&self, key: CertifiedKey) {
        self.inner
            .locks
            .iter()
            .for_each(|shard| shard.update(key.clone()));
    }
}

impl<const SHARDS: usize> ChannelResolver<SHARDS> {
    fn new(key: CertifiedKey) -> Self {
        Self {
            locks: [(); SHARDS].map(|()| sealed::Shard::new(key.clone())),
        }
    }

    // exposed for benching
    #[doc(hidden)]
    pub fn read(&self) -> Arc<CertifiedKey> {
        // choose random shard to reduce contention. unwrap since slice is always non-empty
        self.locks[thread_rng().gen_range(0..SHARDS)].read()
    }
}

impl<const SHARDS: usize> std::fmt::Debug for ChannelResolver<SHARDS> {
    /// Prints only a summary containing the shard count; the keys are left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let summary = format!("[Lock; {SHARDS}]");

        let mut builder = f.debug_struct("ChannelResolver");
        builder.field("locks", &summary);
        builder.finish()
    }
}

impl<const SHARDS: usize> rustls::server::ResolvesServerCert for ChannelResolver<SHARDS> {
    /// Always answers with the channel's current key; the client hello is not inspected.
    fn resolve(&self, _client_hello: rustls::server::ClientHello) -> Option<Arc<CertifiedKey>> {
        let current = self.read();
        Some(current)
    }
}