rustls_config_stream/
lib.rs

//! [`rustls::ServerConfig`] provider backed by an async stream.
//!
//! This crate exposes a [`ServerConfigProvider`] that holds the "current"
//! TLS server configuration and updates it whenever a new config arrives from a
//! user-supplied stream (see [`ServerConfigStreamBuilder`]).
//!
//! When the stream fails, the background task attempts to re-create it via the
//! builder, using exponential backoff (10ms -> 10s, doubling).
//! Call [`ServerConfigProvider::get_config`] whenever you need an [`Arc<ServerConfig>`].
//!
//! # Overview
//! - Implement [`ServerConfigStreamBuilder`] to produce a stream of fresh
//!   `ServerConfig` instances (e.g., reading from disk, a secret store, or
//!   watching a certificate manager).
//! - Start the provider with [`ServerConfigProvider::start`].
//! - Use [`ServerConfigProvider::get_config`] wherever you need the current
//!   config (e.g., inside an acceptor loop).
//! - Optionally monitor liveness via [`ServerConfigProvider::stream_healthy`].
//!
//! # Tracing
//! If the `tracing` feature is enabled, the provider emits diagnostics
//! (debug/info/error) about updates and reconnection attempts.
23
24use std::{
25    sync::{
26        Arc,
27        atomic::{AtomicBool, Ordering},
28    },
29    time::Duration,
30};
31
32use arc_swap::ArcSwap;
33use rustls::{ServerConfig, server::VerifierBuilderError};
34use thiserror::Error;
35use tokio::time::sleep;
36use tokio_stream::{Stream, StreamExt};
37
38#[cfg(feature = "tracing")]
39use tracing::{debug, error, info};
40
41/// Errors that can occur while building or consuming a server-config stream.
42///
43/// These represent failures either from the user-provided stream/builder
44/// or from `rustls` itself.
45#[derive(Debug, Error)]
46pub enum ServerConfigStreamError {
47    /// The underlying stream produced an error.
48    ///
49    /// This is used to wrap arbitrary stream provider errors.
50    #[error("stream provider error")]
51    StreamError(Box<dyn std::error::Error + Send + Sync + 'static>),
52
53    /// The stream completed without yielding an initial `ServerConfig`.
54    ///
55    /// [`ServerConfigProvider::start`] requires at least one item to seed
56    /// the provider; otherwise startup fails with this error.
57    #[error("empty stream")]
58    EmptyStream,
59
60    /// The builder failed to construct a stream.
61    ///
62    /// The provider will surface this when initial construction fails.
63    #[error("could not build stream")]
64    StreamBuilderError(Box<dyn std::error::Error + Send + Sync + 'static>),
65
66    /// Error originating from `rustls` certificate verifier construction.
67    #[error("cert verifier builder error")]
68    VerifierBuilderError(#[from] VerifierBuilderError),
69
70    /// The builder/stream did not provide a [`CertifiedKey`]
71    #[error("missing server certified key")]
72    MissingCertifiedKey,
73
74    /// The builder/stream did not provide any root certificates
75    #[error("missing root certificates")]
76    MissingRoots,
77
78    /// Wrapper for any `rustls` error.
79    #[error("rustls error")]
80    RustlsError(#[from] rustls::Error),
81}
82
83/// A factory for producing a stream of `rustls::ServerConfig`.
84///
85/// Implement this trait to define how your application sources TLS configs
86/// (e.g., file watchers, secret managers, pull-from-API).
87///
88/// The returned stream should yield *complete* `ServerConfig` values. Each
89/// item replaces the provider's current config atomically (via `ArcSwap`).
90///
91/// # Contract
92/// - `build()` should return a stream that eventually yields at least one
93///   `ServerConfig` during initial startup. If it doesn't, startup will fail
94///   with [`ServerConfigStreamError::EmptyStream`].
95/// - On stream failure, the provider will call `build()` again with backoff.
96/// - Items from the stream should be independent `Arc<ServerConfig>` values.
97///
98/// # Examples
99/// ```rust,ignore
100/// use std::sync::Arc;
101/// use rustls::ServerConfig;
102/// use tokio_stream::{Stream, wrappers::ReceiverStream};
103///
104/// struct MyConfigProvider;
105///
106/// impl ServerConfigStreamBuilder for MyConfigProvider {
107///     type ConfigStream = ReceiverStream<Result<Arc<ServerConfig>, ServerConfigStreamError>>;
108///
109///     async fn build(
110///         &mut self,
111///     ) -> Result<Self::ConfigStream, ServerConfigStreamError> {
112///         // Construct a stream that yields ServerConfig updates.
113///         // See the SPIFFE implementation in `rustls-spiffe` for a full example.
114///         unimplemented!()
115///     }
116/// }
117/// ```
118pub trait ServerConfigStreamBuilder {
119    /// The stream type produced by this builder.
120    ///
121    /// Each item is either a fresh `ServerConfig` or an error explaining why
122    /// the update failed.
123    type ConfigStream: Stream<Item = Result<Arc<ServerConfig>, ServerConfigStreamError>>
124        + Send
125        + Sync
126        + Unpin
127        + 'static;
128
129    /// Asynchronously construct a new configuration stream.
130    ///
131    /// The provider will:
132    /// - call this once during startup to obtain the initial stream,
133    /// - read the *first* config to seed its state,
134    /// - continue to poll the provided stream for new configs
135    /// - upon stream failure or completion, call it again with
136    ///   exponential backoff until a new stream is available.
137    fn build(
138        &mut self,
139    ) -> impl std::future::Future<Output = Result<Self::ConfigStream, ServerConfigStreamError>> + Send;
140}
141
142/// Holds the current `ServerConfig` and refreshes it from an async stream.
143///
144/// Internally uses `ArcSwap<ServerConfig>` to provide lock-free, atomic swaps
145/// of the active TLS configuration. Call [`get_config`](Self::get_config) to
146/// obtain an `Arc<ServerConfig>` for acceptors or handshakes.
147///
148/// Liveness of the underlying stream can be checked via
149/// [`stream_healthy`](Self::stream_healthy).
150///
151/// # Concurrency
152/// Reads (`get_config`) are wait-free and do not block updates. Updates occur
153/// on a background task that listens to the user-provided stream.
154///
155/// # Backoff & Recovery
156/// When the stream ends or errors, the provider:
157/// - Marks itself unhealthy,
158/// - Rebuilds the stream via the builder,
159/// - Retries with exponential backoff starting at 10ms and capping at 10s,
160/// - Resets backoff after a successful re-establishment.
161///
162/// # Examples
163/// ```rust,ignore
164/// # use std::sync::Arc;
165/// # use rustls::ServerConfig;
166/// # use your_crate::{ServerConfigProvider, ServerConfigStreamBuilder};
167/// # async fn example(mut builder: impl ServerConfigStreamBuilder + Send + 'static) -> anyhow::Result<()> {
168/// let provider: Arc<ServerConfigProvider> = ServerConfigProvider::start(builder).await?;
169///
170/// // Later, when accepting connections:
171/// let cfg: Arc<ServerConfig> = provider.get_config();
172/// // Use `cfg` with your TLS acceptor/handshake.
173/// # Ok(())
174/// # }
175/// ```
176pub struct ServerConfigProvider {
177    /// The current, atomically-swappable server configuration.
178    inner: ArcSwap<ServerConfig>,
179
180    /// Health flag for the underlying stream (true = healthy).
181    stream_healthy: AtomicBool,
182}
183
184impl ServerConfigProvider {
185    /// Initializes the provider and spawn the background refresh task.
186    ///
187    /// This awaits the first item from the builder's stream to seed the
188    /// internal configuration. It then spawns a task that continuously reads
189    /// subsequent updates, atomically swapping them into place.
190    ///
191    /// On stream failure or completion, the task attempts to rebuild the
192    /// stream using exponential backoff (initial 10ms, max 10s, doubling).
193    ///
194    /// Returns an [`Arc<ServerConfigProvider>`]
195    ///
196    /// # Errors
197    /// - [`ServerConfigStreamError::EmptyStream`]: the initial stream yielded no item.
198    /// - [`ServerConfigStreamError::StreamBuilderError`]: building the stream failed.
199    /// - [`ServerConfigStreamError`] variants wrapping errors from your builder or `rustls`.
200    pub async fn start<B>(mut builder: B) -> Result<Arc<Self>, ServerConfigStreamError>
201    where
202        B: ServerConfigStreamBuilder + Send + 'static,
203    {
204        let mut stream = builder.build().await?;
205        let initial = stream
206            .next()
207            .await
208            .ok_or(ServerConfigStreamError::EmptyStream)??;
209        let this = Arc::new(Self {
210            inner: ArcSwap::from(initial),
211            stream_healthy: AtomicBool::new(true),
212        });
213        let ret = this.clone();
214
215        tokio::spawn(async move {
216            let initial_delay = Duration::from_millis(10);
217            let mut delay = initial_delay;
218            let max_delay = Duration::from_secs(10);
219            loop {
220                match stream.next().await {
221                    Some(Ok(server_config)) => {
222                        this.inner.store(server_config);
223
224                        #[cfg(feature = "tracing")]
225                        debug!(name: "server_config_provider", "stored updated server config from stream");
226                    }
227                    Some(Err(_)) | None => {
228                        this.stream_healthy.store(false, Ordering::Relaxed);
229
230                        #[cfg(feature = "tracing")]
231                        error!("config stream returned error or none, trying to build new stream");
232
233                        match builder.build().await {
234                            Ok(s) => {
235                                this.stream_healthy.store(true, Ordering::Relaxed);
236                                delay = initial_delay;
237                                stream = s;
238
239                                #[cfg(feature = "tracing")]
240                                info!(name: "server_config_provider", "reestablished server config stream");
241                            }
242                            Err(err) => {
243                                #[cfg(feature = "tracing")]
244                                error!(name: "server_config_provider", retry_in_ms = delay.as_millis(), error = %err, "failed to reestablish server config stream");
245
246                                sleep(delay).await;
247                                delay = (delay * 2).min(max_delay);
248                            }
249                        };
250                    }
251                }
252            }
253        });
254        Ok(ret)
255    }
256
257    /// Returns whether the stream is currently healthy.
258    ///
259    /// This flag is set to `false` when the stream errors or ends, and set
260    /// back to `true` after a successful rebuild.
261    pub fn stream_healthy(&self) -> bool {
262        self.stream_healthy.load(Ordering::Relaxed)
263    }
264
265    /// Get the current `ServerConfig`.
266    ///
267    /// This is a cheap, lock-free read that clones the internal `Arc`.
268    /// Callers can hold onto the returned `Arc<ServerConfig>` as long as
269    /// needed; updates will affect future calls, not the already-held value.
270    pub fn get_config(&self) -> Arc<ServerConfig> {
271        self.inner.load_full()
272    }
273}
274
275#[cfg(test)]
276mod tests {
277    use std::{
278        collections::VecDeque,
279        sync::{
280            Arc,
281            atomic::{AtomicUsize, Ordering},
282        },
283    };
284
285    use rustls::ServerConfig;
286    use thiserror::Error;
287    use tokio::sync::{Mutex, mpsc};
288    use tokio_stream::wrappers::ReceiverStream;
289
290    use crate::{ServerConfigProvider, ServerConfigStreamBuilder, ServerConfigStreamError};
291
292    #[derive(Error, Debug)]
293    struct MockError(&'static str);
294    impl std::fmt::Display for MockError {
295        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
296            f.write_str(self.0)
297        }
298    }
299
300    #[derive(Debug)]
301    struct NoopResolver;
302    impl rustls::server::ResolvesServerCert for NoopResolver {
303        fn resolve(
304            &self,
305            _client_hello: rustls::server::ClientHello<'_>,
306        ) -> Option<Arc<rustls::sign::CertifiedKey>> {
307            None
308        }
309    }
310
311    fn empty_server_config() -> Arc<ServerConfig> {
312        Arc::from(
313            ServerConfig::builder()
314                .with_no_client_auth()
315                .with_cert_resolver(Arc::from(NoopResolver)),
316        )
317    }
318
319    #[derive(Debug)]
320    struct MockServerConfigStreamBuilder {
321        streams:
322            Mutex<VecDeque<mpsc::Receiver<Result<Arc<ServerConfig>, ServerConfigStreamError>>>>,
323        builds: Arc<AtomicUsize>,
324    }
325
326    impl MockServerConfigStreamBuilder {
327        fn new(
328            streams: Vec<mpsc::Receiver<Result<Arc<ServerConfig>, ServerConfigStreamError>>>,
329        ) -> Self {
330            let builds = Arc::from(AtomicUsize::new(0));
331            let streams = Mutex::from(VecDeque::from(streams));
332            Self { streams, builds }
333        }
334    }
335
336    impl ServerConfigStreamBuilder for MockServerConfigStreamBuilder {
337        type ConfigStream = ReceiverStream<Result<Arc<ServerConfig>, ServerConfigStreamError>>;
338
339        async fn build(&mut self) -> Result<Self::ConfigStream, ServerConfigStreamError> {
340            self.builds.fetch_add(1, Ordering::SeqCst);
341            let rx = self.streams.lock().await.pop_front().ok_or_else(|| {
342                ServerConfigStreamError::StreamBuilderError(MockError("mock stream error").into())
343            })?;
344            Ok(ReceiverStream::new(rx))
345        }
346    }
347
348    #[tokio::test]
349    async fn start_fails_given_initial_stream_build_failure() {
350        let builder = MockServerConfigStreamBuilder::new(vec![]);
351
352        let res = ServerConfigProvider::start(builder).await;
353        match res {
354            Err(ServerConfigStreamError::StreamBuilderError(_)) => { /* test pass */ }
355            _ => panic!("expected ServerConfigStreamError::EmptyStream"),
356        }
357    }
358
359    #[tokio::test]
360    async fn start_fails_when_stream_is_empty() {
361        let (tx, rx) = mpsc::channel(1);
362
363        // drop tx so stream returns Poll::Ready(None)
364        std::mem::drop(tx);
365
366        let builder = MockServerConfigStreamBuilder::new(vec![rx]);
367
368        let res = ServerConfigProvider::start(builder).await;
369        match res {
370            Err(ServerConfigStreamError::EmptyStream) => { /* test pass */ }
371            _ => panic!("expected ServerConfigStreamError::EmptyStream"),
372        }
373    }
374
375    #[tokio::test]
376    async fn start_fails_when_first_result_is_err() {
377        let (tx, rx) = mpsc::channel(1);
378        let builder = MockServerConfigStreamBuilder::new(vec![rx]);
379
380        tx.send(Err(ServerConfigStreamError::StreamError(
381            MockError("fake error").into(),
382        )))
383        .await
384        .unwrap();
385
386        let res = ServerConfigProvider::start(builder).await;
387        match res {
388            Err(ServerConfigStreamError::StreamError(err)) => {
389                assert_eq!(err.to_string(), "fake error");
390            }
391            _ => panic!("expected ServerConfigStreamError::EmptyStream"),
392        }
393    }
394
395    #[tokio::test]
396    async fn start_and_initial_config_is_loaded() {
397        let (tx, rx) = mpsc::channel(1);
398        let builder = MockServerConfigStreamBuilder::new(vec![rx]);
399        let expected = empty_server_config();
400        tx.send(Ok(expected.clone())).await.unwrap();
401        let provider = ServerConfigProvider::start(builder).await.unwrap();
402
403        let got = provider.get_config();
404
405        assert!(Arc::ptr_eq(&got, &expected));
406        assert!(provider.stream_healthy());
407    }
408
409    #[tokio::test]
410    async fn single_stream_config_hot_swap() {
411        let (tx, rx) = mpsc::channel(1);
412        let builder = MockServerConfigStreamBuilder::new(vec![rx]);
413
414        let initial = empty_server_config();
415        tx.send(Ok(initial.clone())).await.unwrap();
416        let provider = ServerConfigProvider::start(builder).await.unwrap();
417        let got = provider.get_config();
418        assert!(Arc::ptr_eq(&got, &initial));
419        assert!(provider.stream_healthy());
420
421        for i in 0..10 {
422            let expected = empty_server_config();
423            tx.send(Ok(expected.clone())).await.unwrap();
424
425            tokio::task::yield_now().await;
426            let got = provider.get_config();
427            assert!(
428                Arc::ptr_eq(&got, &expected),
429                "config not updated on iter {i}"
430            );
431            assert!(provider.stream_healthy());
432        }
433    }
434
435    #[tokio::test]
436    async fn stream_failure_triggers_rebuild() {
437        let (tx1, rx1) = mpsc::channel(1);
438        let (tx2, rx2) = mpsc::channel(1);
439        let builder = MockServerConfigStreamBuilder::new(vec![rx1, rx2]);
440        let builds = &builder.builds.clone();
441        let initial = empty_server_config();
442        tx1.send(Ok(initial.clone())).await.unwrap();
443        let provider = ServerConfigProvider::start(builder).await.unwrap();
444        assert!(Arc::ptr_eq(&provider.get_config(), &initial));
445        assert!(provider.stream_healthy());
446
447        // inject an error and wait for provider to go unhealthy
448        tx1.send(Err(ServerConfigStreamError::StreamError(
449            MockError("fake error").into(),
450        )))
451        .await
452        .unwrap();
453
454        // polling to assert provider.stream_healthy
455        // goes to false proved to be flaky due to it
456        // going back to healthy too fast.
457
458        // check that it rebuilt the stream via the provider
459        tokio::task::yield_now().await;
460        assert_eq!(builds.load(Ordering::SeqCst), 2);
461
462        // push a new config and check that it's loaded
463        let new = empty_server_config();
464        tx2.send(Ok(new.clone())).await.unwrap();
465        tokio::task::yield_now().await;
466
467        // check that stream is healthy and new config was loaded
468        assert!(provider.stream_healthy());
469        assert!(Arc::ptr_eq(&provider.get_config(), &new))
470    }
471
472    #[tokio::test(flavor = "current_thread", start_paused = true)]
473    async fn stream_rebuild_goes_into_backoff() {
474        let (tx, rx) = mpsc::channel(1);
475        let builder = MockServerConfigStreamBuilder::new(vec![rx]);
476        let builds = &builder.builds.clone();
477        let initial = empty_server_config();
478        tx.send(Ok(initial.clone())).await.unwrap();
479        let provider = ServerConfigProvider::start(builder).await.unwrap();
480        assert!(Arc::ptr_eq(&provider.get_config(), &initial));
481        assert!(provider.stream_healthy());
482        assert_eq!(builds.load(Ordering::SeqCst), 1);
483
484        tx.send(Err(ServerConfigStreamError::StreamError(
485            MockError("fake error").into(),
486        )))
487        .await
488        .unwrap();
489        tokio::task::yield_now().await;
490        // assert it tried to rebuild stream but is still unhealthy since
491        // the MockServerConfigBuilder will return an error as the streams
492        // vector is empty.
493        assert_eq!(builds.load(Ordering::SeqCst), 2);
494        assert!(!provider.stream_healthy.load(Ordering::Relaxed));
495    }
496}