rustls_config_stream/
server.rs

1// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
2
3use std::{
4    sync::{
5        Arc,
6        atomic::{AtomicBool, Ordering},
7    },
8    time::Duration,
9};
10
11use arc_swap::ArcSwap;
12use rustls::{ServerConfig, server::VerifierBuilderError};
13use thiserror::Error;
14use tokio::time::sleep;
15use tokio_stream::{Stream, StreamExt};
16
17#[cfg(feature = "tracing")]
18use tracing::{debug, error, info};
19
20/// Errors that can occur while building or consuming a server-config stream.
21///
22/// These represent failures either from the user-provided stream/builder
23/// or from [`rustls`] itself.
24#[derive(Debug, Error)]
25pub enum ServerConfigStreamError {
26    /// The underlying stream produced an error.
27    ///
28    /// This is used to wrap arbitrary stream provider errors.
29    #[error("stream provider error")]
30    StreamError(Box<dyn std::error::Error + Send + Sync + 'static>),
31
32    /// The stream completed without yielding an initial [`ServerConfig`].
33    ///
34    /// [`ServerConfigProvider::start`] requires at least one item to seed
35    /// the provider; otherwise startup fails with this error.
36    #[error("empty stream")]
37    EmptyStream,
38
39    /// The builder failed to construct a stream.
40    ///
41    /// The provider will surface this when initial construction fails.
42    #[error("could not build stream")]
43    StreamBuilderError(Box<dyn std::error::Error + Send + Sync + 'static>),
44
45    /// Error originating from [`rustls`] certificate verifier construction.
46    #[error("cert verifier builder error")]
47    VerifierBuilderError(#[from] VerifierBuilderError),
48
49    /// The builder/stream did not provide a [`rustls::sign::CertifiedKey`]
50    #[error("missing server certified key")]
51    MissingCertifiedKey,
52
53    /// The builder/stream did not provide any root certificates resulting in an empty [`rustls::RootCertStore`]
54    #[error("missing root certificates")]
55    MissingRoots,
56
57    /// Wrapper for any [`rustls::Error`] error.
58    #[error("rustls error")]
59    RustlsError(#[from] rustls::Error),
60}
61
62/// A factory for producing a stream of [`rustls::ServerConfig`].
63///
64/// Implement this trait to define how your application sources TLS configs
65/// (e.g., file watchers, secret managers, pull-from-API).
66///
67/// The returned stream should yield *complete* [`ServerConfig`] values. Each
68/// item replaces the provider's current config atomically (via [`ArcSwap`]).
69///
70/// # Contract
71/// - [`build()`](ServerConfigStreamBuilder::build) should return a stream that eventually yields at least one
72///   [`ServerConfig`] during initial startup. If it doesn't, startup will fail
73///   with [`ServerConfigStreamError::EmptyStream`].
74/// - On stream failure, the provider will call [`build()`](ServerConfigStreamBuilder::build) again with backoff.
75/// - Items from the stream should be independent [`Arc<ServerConfig>`] values.
76///
77/// # Examples
78/// ```no_run
79/// use std::sync::Arc;
80/// use rustls::ServerConfig;
81/// use tokio_stream::{Stream, wrappers::ReceiverStream};
82///
83/// struct MyConfigProvider;
84///
85/// impl ServerConfigStreamBuilder for MyConfigProvider {
86///     type ConfigStream = ReceiverStream<Result<Arc<ServerConfig>, ServerConfigStreamError>>;
87///
88///     async fn build(
89///         &mut self,
90///     ) -> Result<Self::ConfigStream, ServerConfigStreamError> {
91///         // Construct a stream that yields ServerConfig updates.
92///         // See the SPIFFE implementation in [`rustls-spiffe`] for a full example.
93///         unimplemented!()
94///     }
95/// }
96/// ```
97pub trait ServerConfigStreamBuilder {
98    /// The stream type produced by this builder.
99    ///
100    /// Each item is either a fresh [`ServerConfig`] or an error explaining why
101    /// the update failed.
102    type ConfigStream: Stream<Item = Result<Arc<ServerConfig>, ServerConfigStreamError>>
103        + Send
104        + Sync
105        + Unpin
106        + 'static;
107
108    /// Asynchronously construct a new configuration stream.
109    ///
110    /// The provider will:
111    /// - call this once during startup to obtain the initial stream,
112    /// - read the *first* config to seed its state,
113    /// - continue to poll the provided stream for new configs
114    /// - upon stream failure or completion, call it again with
115    ///   exponential backoff until a new stream is available.
116    fn build(
117        &mut self,
118    ) -> impl std::future::Future<Output = Result<Self::ConfigStream, ServerConfigStreamError>> + Send;
119}
120
121/// Holds the current [`ServerConfig`] and refreshes it from an async stream.
122///
123/// Internally uses [`ArcSwap<ServerConfig>`] to provide lock-free, atomic swaps
124/// of the active TLS configuration. Call [`get_config`](Self::get_config) to
125/// obtain an [`Arc<ServerConfig>`] for acceptors or handshakes.
126///
127/// Liveness of the underlying stream can be checked via
128/// [`stream_healthy`](Self::stream_healthy).
129///
130/// # Concurrency
131/// Reads [`get_config()`](Self::get_config) are wait-free and do not block updates.
132/// Updates occur on a background task that listens to the user-provided stream.
133///
134/// # Backoff & Recovery
135/// When the stream ends or errors, the provider:
136/// - Marks itself unhealthy,
137/// - Rebuilds the stream via the builder,
138/// - Retries with exponential backoff starting at 10ms and capping at 10s,
139/// - Resets backoff after a successful re-establishment.
140///
141/// # Examples
142/// ```no_run
143/// let config_stream_builder =
144///     SpiffeServerConfigStream::builder(vec!["example.org".try_into().unwrap()]);
145/// let config_provider = ServerConfigProvider::start(config_stream_builder)
146///     .await
147///     .unwrap();
148/// let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
149///     .await
150///     .unwrap();
151///
152/// loop {
153///     let (stream, _) = listener.accept().await.unwrap();
154///
155///     let acceptor =
156///         tokio_rustls::LazyConfigAcceptor::new(rustls::server::Acceptor::default(), stream);
157///     tokio::pin!(acceptor);
158///
159///     let config_provider = config_provider.clone();
160///     match acceptor.as_mut().await {
161///         Ok(start) => {
162///             tokio::spawn(async move {
163///                 if !config_provider.stream_healthy() {
164///                     warn!(
165///                         "config provider does not have healthy stream; TLS config may be out of date"
166///                     );
167///                 }
168///                 let config = config_provider.get_config();
169///                 match start.into_stream(config).await {
170///                     Ok(stream) => { /* serve some app (e.g. hyper, tower, axum) */ },
171///                     Err(err) => { /* handle error */ }
172///                 }
173///             })
174///         }
175///         Err(_) => { /* handle error */ }
176///     }
177/// }
178/// ```
179pub struct ServerConfigProvider {
180    /// The current, atomically-swappable server configuration.
181    inner: ArcSwap<ServerConfig>,
182
183    /// Health flag for the underlying stream (true = healthy).
184    stream_healthy: AtomicBool,
185}
186
187impl ServerConfigProvider {
188    /// Initializes the provider and spawn the background refresh task.
189    ///
190    /// This awaits the first item from the builder's stream to seed the
191    /// internal configuration. It then spawns a task that continuously reads
192    /// subsequent updates, atomically swapping them into place.
193    ///
194    /// On stream failure or completion, the task attempts to rebuild the
195    /// stream using exponential backoff (initial 10ms, max 10s, doubling).
196    ///
197    /// Returns an [`Arc<ServerConfigProvider>`]
198    ///
199    /// # Errors
200    /// - [`ServerConfigStreamError::EmptyStream`]: the initial stream yielded no item.
201    /// - [`ServerConfigStreamError::StreamBuilderError`]: building the stream failed.
202    /// - [`ServerConfigStreamError`] variants wrapping errors from your builder or `rustls`.
203    pub async fn start<B>(mut builder: B) -> Result<Arc<Self>, ServerConfigStreamError>
204    where
205        B: ServerConfigStreamBuilder + Send + 'static,
206    {
207        let mut stream = builder.build().await?;
208        let initial = stream
209            .next()
210            .await
211            .ok_or(ServerConfigStreamError::EmptyStream)??;
212        let this = Arc::new(Self {
213            inner: ArcSwap::from(initial),
214            stream_healthy: AtomicBool::new(true),
215        });
216        let ret = this.clone();
217
218        tokio::spawn(async move {
219            let initial_delay = Duration::from_millis(10);
220            let mut delay = initial_delay;
221            let max_delay = Duration::from_secs(10);
222            loop {
223                match stream.next().await {
224                    Some(Ok(server_config)) => {
225                        this.inner.store(server_config);
226
227                        #[cfg(feature = "tracing")]
228                        debug!(name: "server_config_provider", "stored updated server config from stream");
229                    }
230                    Some(Err(_)) | None => {
231                        this.stream_healthy.store(false, Ordering::Relaxed);
232
233                        #[cfg(feature = "tracing")]
234                        error!("config stream returned error or none, trying to build new stream");
235
236                        match builder.build().await {
237                            Ok(s) => {
238                                this.stream_healthy.store(true, Ordering::Relaxed);
239                                delay = initial_delay;
240                                stream = s;
241
242                                #[cfg(feature = "tracing")]
243                                info!(name: "server_config_provider", "reestablished server config stream");
244                            }
245                            Err(err) => {
246                                #[cfg(feature = "tracing")]
247                                error!(name: "server_config_provider", retry_in_ms = delay.as_millis(), error = %err, "failed to reestablish server config stream");
248
249                                sleep(delay).await;
250                                delay = (delay * 2).min(max_delay);
251                            }
252                        };
253                    }
254                }
255            }
256        });
257        Ok(ret)
258    }
259
260    /// Returns whether the stream is currently healthy.
261    ///
262    /// This flag is set to `false` when the stream errors or ends, and set
263    /// back to `true` after a successful rebuild.
264    pub fn stream_healthy(&self) -> bool {
265        self.stream_healthy.load(Ordering::Relaxed)
266    }
267
268    /// Get the current [`ServerConfig`].
269    ///
270    /// This is a cheap, lock-free read that loads the internal [`ArcSwap<ServerConfig>`] into an [`Arc<ServerConfig>`]
271    /// Callers can hold onto the returned [`Arc<ServerConfig>`] as long as
272    /// needed; updates will affect future calls, not the already-held value.
273    pub fn get_config(&self) -> Arc<ServerConfig> {
274        self.inner.load_full()
275    }
276}
277
#[cfg(test)]
mod tests {
    use std::{
        collections::VecDeque,
        sync::{
            Arc,
            atomic::{AtomicUsize, Ordering},
        },
    };

    use rustls::ServerConfig;
    use thiserror::Error;
    use tokio::sync::{Mutex, mpsc};
    use tokio_stream::wrappers::ReceiverStream;

    use crate::{ServerConfigProvider, ServerConfigStreamBuilder, ServerConfigStreamError};

    /// Item type carried by the mock config streams.
    type ConfigResult = Result<Arc<ServerConfig>, ServerConfigStreamError>;

    /// Minimal error type used to fabricate stream/builder failures.
    #[derive(Error, Debug)]
    struct MockError(&'static str);
    impl std::fmt::Display for MockError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.write_str(self.0)
        }
    }

    /// Certificate resolver that never resolves anything; it only exists so
    /// that a `ServerConfig` can be constructed without real key material.
    #[derive(Debug)]
    struct NoopResolver;
    impl rustls::server::ResolvesServerCert for NoopResolver {
        fn resolve(
            &self,
            _client_hello: rustls::server::ClientHello<'_>,
        ) -> Option<Arc<rustls::sign::CertifiedKey>> {
            None
        }
    }

    /// Builds a fresh, distinct `ServerConfig` allocation. Tests compare
    /// configs by pointer identity, so every call must return a new `Arc`.
    fn empty_server_config() -> Arc<ServerConfig> {
        Arc::new(
            ServerConfig::builder()
                .with_no_client_auth()
                .with_cert_resolver(Arc::new(NoopResolver)),
        )
    }

    /// Builder that hands out pre-made channel receivers, one per `build()`
    /// call, and fails with `StreamBuilderError` once they are used up.
    #[derive(Debug)]
    struct MockServerConfigStreamBuilder {
        /// Receivers still to be handed out, in order.
        streams: Mutex<VecDeque<mpsc::Receiver<ConfigResult>>>,
        /// Number of times `build()` has been called (including failures).
        builds: Arc<AtomicUsize>,
    }

    impl MockServerConfigStreamBuilder {
        fn new(streams: Vec<mpsc::Receiver<ConfigResult>>) -> Self {
            Self {
                streams: Mutex::new(VecDeque::from(streams)),
                builds: Arc::new(AtomicUsize::new(0)),
            }
        }
    }

    impl ServerConfigStreamBuilder for MockServerConfigStreamBuilder {
        type ConfigStream = ReceiverStream<ConfigResult>;

        async fn build(&mut self) -> Result<Self::ConfigStream, ServerConfigStreamError> {
            self.builds.fetch_add(1, Ordering::SeqCst);
            let rx = self.streams.lock().await.pop_front().ok_or_else(|| {
                ServerConfigStreamError::StreamBuilderError(MockError("mock stream error").into())
            })?;
            Ok(ReceiverStream::new(rx))
        }
    }

    #[tokio::test]
    async fn start_fails_given_initial_stream_build_failure() {
        let builder = MockServerConfigStreamBuilder::new(vec![]);

        match ServerConfigProvider::start(builder).await {
            Err(ServerConfigStreamError::StreamBuilderError(_)) => { /* test pass */ }
            Err(other) => {
                panic!("expected ServerConfigStreamError::StreamBuilderError, got {other:?}")
            }
            Ok(_) => panic!("expected ServerConfigStreamError::StreamBuilderError, got Ok"),
        }
    }

    #[tokio::test]
    async fn start_fails_when_stream_is_empty() {
        let (tx, rx) = mpsc::channel(1);

        // drop tx so stream returns Poll::Ready(None)
        drop(tx);

        let builder = MockServerConfigStreamBuilder::new(vec![rx]);

        match ServerConfigProvider::start(builder).await {
            Err(ServerConfigStreamError::EmptyStream) => { /* test pass */ }
            Err(other) => panic!("expected ServerConfigStreamError::EmptyStream, got {other:?}"),
            Ok(_) => panic!("expected ServerConfigStreamError::EmptyStream, got Ok"),
        }
    }

    #[tokio::test]
    async fn start_fails_when_first_result_is_err() {
        let (tx, rx) = mpsc::channel(1);
        let builder = MockServerConfigStreamBuilder::new(vec![rx]);

        tx.send(Err(ServerConfigStreamError::StreamError(
            MockError("fake error").into(),
        )))
        .await
        .unwrap();

        match ServerConfigProvider::start(builder).await {
            Err(ServerConfigStreamError::StreamError(err)) => {
                assert_eq!(err.to_string(), "fake error");
            }
            Err(other) => panic!("expected ServerConfigStreamError::StreamError, got {other:?}"),
            Ok(_) => panic!("expected ServerConfigStreamError::StreamError, got Ok"),
        }
    }

    #[tokio::test]
    async fn start_and_initial_config_is_loaded() {
        let (tx, rx) = mpsc::channel(1);
        let builder = MockServerConfigStreamBuilder::new(vec![rx]);
        let expected = empty_server_config();
        tx.send(Ok(expected.clone())).await.unwrap();
        let provider = ServerConfigProvider::start(builder).await.unwrap();

        let got = provider.get_config();

        assert!(Arc::ptr_eq(&got, &expected));
        assert!(provider.stream_healthy());
    }

    #[tokio::test]
    async fn single_stream_config_hot_swap() {
        let (tx, rx) = mpsc::channel(1);
        let builder = MockServerConfigStreamBuilder::new(vec![rx]);

        let initial = empty_server_config();
        tx.send(Ok(initial.clone())).await.unwrap();
        let provider = ServerConfigProvider::start(builder).await.unwrap();
        let got = provider.get_config();
        assert!(Arc::ptr_eq(&got, &initial));
        assert!(provider.stream_healthy());

        for i in 0..10 {
            let expected = empty_server_config();
            tx.send(Ok(expected.clone())).await.unwrap();

            // let the background task run and store the new config
            tokio::task::yield_now().await;
            let got = provider.get_config();
            assert!(
                Arc::ptr_eq(&got, &expected),
                "config not updated on iter {i}"
            );
            assert!(provider.stream_healthy());
        }
    }

    #[tokio::test]
    async fn stream_failure_triggers_rebuild() {
        let (tx1, rx1) = mpsc::channel(1);
        let (tx2, rx2) = mpsc::channel(1);
        let builder = MockServerConfigStreamBuilder::new(vec![rx1, rx2]);
        let builds = Arc::clone(&builder.builds);
        let initial = empty_server_config();
        tx1.send(Ok(initial.clone())).await.unwrap();
        let provider = ServerConfigProvider::start(builder).await.unwrap();
        assert!(Arc::ptr_eq(&provider.get_config(), &initial));
        assert!(provider.stream_healthy());

        tx1.send(Err(ServerConfigStreamError::StreamError(
            MockError("fake error").into(),
        )))
        .await
        .unwrap();

        // polling to assert provider.stream_healthy
        // goes to false proved to be flaky due to it
        // going back to healthy too fast.

        // check that it rebuilt the stream via the provider
        tokio::task::yield_now().await;
        assert_eq!(builds.load(Ordering::SeqCst), 2);

        // push a new config and check that it's loaded
        let new = empty_server_config();
        tx2.send(Ok(new.clone())).await.unwrap();
        tokio::task::yield_now().await;

        // check that stream is healthy and new config was loaded
        assert!(provider.stream_healthy());
        assert!(Arc::ptr_eq(&provider.get_config(), &new));
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn stream_rebuild_goes_into_backoff() {
        let (tx, rx) = mpsc::channel(1);
        let builder = MockServerConfigStreamBuilder::new(vec![rx]);
        let builds = Arc::clone(&builder.builds);
        let initial = empty_server_config();
        tx.send(Ok(initial.clone())).await.unwrap();
        let provider = ServerConfigProvider::start(builder).await.unwrap();
        assert!(Arc::ptr_eq(&provider.get_config(), &initial));
        assert!(provider.stream_healthy());
        assert_eq!(builds.load(Ordering::SeqCst), 1);

        tx.send(Err(ServerConfigStreamError::StreamError(
            MockError("fake error").into(),
        )))
        .await
        .unwrap();
        tokio::task::yield_now().await;
        // assert it tried to rebuild stream but is still unhealthy since
        // the MockServerConfigStreamBuilder will return an error as the
        // streams queue is empty.
        assert_eq!(builds.load(Ordering::SeqCst), 2);
        assert!(!provider.stream_healthy());
    }
}
496        assert!(!provider.stream_healthy.load(Ordering::Relaxed));
497    }
498}