rustls_config_stream/lib.rs
1//! [`rustls::ServerConfig`] provider backed by an async stream.
2//!
3//! This module exposes a [`ServerConfigProvider`] that holds the "current"
4//! TLS server configuration and updates it whenever a new config arrives from a
5//! user-supplied stream (see [`ServerConfigStreamBuilder`]).
6//!
7//! The background task performs exponential backoff (10ms -> 10s, doubling)
8//! when the stream fails, and attempts to re-create the stream via the builder.
9//! Call [`ServerConfigProvider::get_config`] whenever you need an [`Arc<ServerConfig>`].
10//!
11//! # Overview
12//! - Implement [`ServerConfigStreamBuilder`] to produce a stream of fresh
13//! `ServerConfig` instances (e.g., reading from disk, a secret store, or
14//! watching a certificate manager).
15//! - Start the provider with [`ServerConfigProvider::start`].
16//! - Use [`ServerConfigProvider::get_config`] wherever you need the current
17//! config (e.g., inside an acceptor loop).
18//! - Optionally monitor liveness via [`ServerConfigProvider::stream_healthy`].
19//!
20//! # Tracing
21//! If the `tracing` feature is enabled, the provider will emit diagnostics
22//! (debug/info/error) about updates and reconnection attempts.
23
24use std::{
25 sync::{
26 Arc,
27 atomic::{AtomicBool, Ordering},
28 },
29 time::Duration,
30};
31
32use arc_swap::ArcSwap;
33use rustls::{ServerConfig, server::VerifierBuilderError};
34use thiserror::Error;
35use tokio::time::sleep;
36use tokio_stream::{Stream, StreamExt};
37
38#[cfg(feature = "tracing")]
39use tracing::{debug, error, info};
40
41/// Errors that can occur while building or consuming a server-config stream.
42///
43/// These represent failures either from the user-provided stream/builder
44/// or from `rustls` itself.
45#[derive(Debug, Error)]
46pub enum ServerConfigStreamError {
47 /// The underlying stream produced an error.
48 ///
49 /// This is used to wrap arbitrary stream provider errors.
50 #[error("stream provider error")]
51 StreamError(Box<dyn std::error::Error + Send + Sync + 'static>),
52
53 /// The stream completed without yielding an initial `ServerConfig`.
54 ///
55 /// [`ServerConfigProvider::start`] requires at least one item to seed
56 /// the provider; otherwise startup fails with this error.
57 #[error("empty stream")]
58 EmptyStream,
59
60 /// The builder failed to construct a stream.
61 ///
62 /// The provider will surface this when initial construction fails.
63 #[error("could not build stream")]
64 StreamBuilderError(Box<dyn std::error::Error + Send + Sync + 'static>),
65
66 /// Error originating from `rustls` certificate verifier construction.
67 #[error("cert verifier builder error")]
68 VerifierBuilderError(#[from] VerifierBuilderError),
69
70 /// The builder/stream did not provide a [`CertifiedKey`]
71 #[error("missing server certified key")]
72 MissingCertifiedKey,
73
74 /// The builder/stream did not provide any root certificates
75 #[error("missing root certificates")]
76 MissingRoots,
77
78 /// Wrapper for any `rustls` error.
79 #[error("rustls error")]
80 RustlsError(#[from] rustls::Error),
81}
82
83/// A factory for producing a stream of `rustls::ServerConfig`.
84///
85/// Implement this trait to define how your application sources TLS configs
86/// (e.g., file watchers, secret managers, pull-from-API).
87///
88/// The returned stream should yield *complete* `ServerConfig` values. Each
89/// item replaces the provider's current config atomically (via `ArcSwap`).
90///
91/// # Contract
92/// - `build()` should return a stream that eventually yields at least one
93/// `ServerConfig` during initial startup. If it doesn't, startup will fail
94/// with [`ServerConfigStreamError::EmptyStream`].
95/// - On stream failure, the provider will call `build()` again with backoff.
96/// - Items from the stream should be independent `Arc<ServerConfig>` values.
97///
98/// # Examples
99/// ```rust,ignore
100/// use std::sync::Arc;
101/// use rustls::ServerConfig;
102/// use tokio_stream::{Stream, wrappers::ReceiverStream};
103///
104/// struct MyConfigProvider;
105///
106/// impl ServerConfigStreamBuilder for MyConfigProvider {
107/// type ConfigStream = ReceiverStream<Result<Arc<ServerConfig>, ServerConfigStreamError>>;
108///
109/// async fn build(
110/// &mut self,
111/// ) -> Result<Self::ConfigStream, ServerConfigStreamError> {
112/// // Construct a stream that yields ServerConfig updates.
113/// // See the SPIFFE implementation in `rustls-spiffe` for a full example.
114/// unimplemented!()
115/// }
116/// }
117/// ```
118pub trait ServerConfigStreamBuilder {
119 /// The stream type produced by this builder.
120 ///
121 /// Each item is either a fresh `ServerConfig` or an error explaining why
122 /// the update failed.
123 type ConfigStream: Stream<Item = Result<Arc<ServerConfig>, ServerConfigStreamError>>
124 + Send
125 + Sync
126 + Unpin
127 + 'static;
128
129 /// Asynchronously construct a new configuration stream.
130 ///
131 /// The provider will:
132 /// - call this once during startup to obtain the initial stream,
133 /// - read the *first* config to seed its state,
134 /// - continue to poll the provided stream for new configs
135 /// - upon stream failure or completion, call it again with
136 /// exponential backoff until a new stream is available.
137 fn build(
138 &mut self,
139 ) -> impl std::future::Future<Output = Result<Self::ConfigStream, ServerConfigStreamError>> + Send;
140}
141
142/// Holds the current `ServerConfig` and refreshes it from an async stream.
143///
144/// Internally uses `ArcSwap<ServerConfig>` to provide lock-free, atomic swaps
145/// of the active TLS configuration. Call [`get_config`](Self::get_config) to
146/// obtain an `Arc<ServerConfig>` for acceptors or handshakes.
147///
148/// Liveness of the underlying stream can be checked via
149/// [`stream_healthy`](Self::stream_healthy).
150///
151/// # Concurrency
152/// Reads (`get_config`) are wait-free and do not block updates. Updates occur
153/// on a background task that listens to the user-provided stream.
154///
155/// # Backoff & Recovery
156/// When the stream ends or errors, the provider:
157/// - Marks itself unhealthy,
158/// - Rebuilds the stream via the builder,
159/// - Retries with exponential backoff starting at 10ms and capping at 10s,
160/// - Resets backoff after a successful re-establishment.
161///
162/// # Examples
163/// ```rust,ignore
164/// # use std::sync::Arc;
165/// # use rustls::ServerConfig;
166/// # use your_crate::{ServerConfigProvider, ServerConfigStreamBuilder};
167/// # async fn example(mut builder: impl ServerConfigStreamBuilder + Send + 'static) -> anyhow::Result<()> {
168/// let provider: Arc<ServerConfigProvider> = ServerConfigProvider::start(builder).await?;
169///
170/// // Later, when accepting connections:
171/// let cfg: Arc<ServerConfig> = provider.get_config();
172/// // Use `cfg` with your TLS acceptor/handshake.
173/// # Ok(())
174/// # }
175/// ```
176pub struct ServerConfigProvider {
177 /// The current, atomically-swappable server configuration.
178 inner: ArcSwap<ServerConfig>,
179
180 /// Health flag for the underlying stream (true = healthy).
181 stream_healthy: AtomicBool,
182}
183
184impl ServerConfigProvider {
185 /// Initializes the provider and spawn the background refresh task.
186 ///
187 /// This awaits the first item from the builder's stream to seed the
188 /// internal configuration. It then spawns a task that continuously reads
189 /// subsequent updates, atomically swapping them into place.
190 ///
191 /// On stream failure or completion, the task attempts to rebuild the
192 /// stream using exponential backoff (initial 10ms, max 10s, doubling).
193 ///
194 /// Returns an [`Arc<ServerConfigProvider>`]
195 ///
196 /// # Errors
197 /// - [`ServerConfigStreamError::EmptyStream`]: the initial stream yielded no item.
198 /// - [`ServerConfigStreamError::StreamBuilderError`]: building the stream failed.
199 /// - [`ServerConfigStreamError`] variants wrapping errors from your builder or `rustls`.
200 pub async fn start<B>(mut builder: B) -> Result<Arc<Self>, ServerConfigStreamError>
201 where
202 B: ServerConfigStreamBuilder + Send + 'static,
203 {
204 let mut stream = builder.build().await?;
205 let initial = stream
206 .next()
207 .await
208 .ok_or(ServerConfigStreamError::EmptyStream)??;
209 let this = Arc::new(Self {
210 inner: ArcSwap::from(initial),
211 stream_healthy: AtomicBool::new(true),
212 });
213 let ret = this.clone();
214
215 tokio::spawn(async move {
216 let initial_delay = Duration::from_millis(10);
217 let mut delay = initial_delay;
218 let max_delay = Duration::from_secs(10);
219 loop {
220 match stream.next().await {
221 Some(Ok(server_config)) => {
222 this.inner.store(server_config);
223
224 #[cfg(feature = "tracing")]
225 debug!(name: "server_config_provider", "stored updated server config from stream");
226 }
227 Some(Err(_)) | None => {
228 this.stream_healthy.store(false, Ordering::Relaxed);
229
230 #[cfg(feature = "tracing")]
231 error!("config stream returned error or none, trying to build new stream");
232
233 match builder.build().await {
234 Ok(s) => {
235 this.stream_healthy.store(true, Ordering::Relaxed);
236 delay = initial_delay;
237 stream = s;
238
239 #[cfg(feature = "tracing")]
240 info!(name: "server_config_provider", "reestablished server config stream");
241 }
242 Err(err) => {
243 #[cfg(feature = "tracing")]
244 error!(name: "server_config_provider", retry_in_ms = delay.as_millis(), error = %err, "failed to reestablish server config stream");
245
246 sleep(delay).await;
247 delay = (delay * 2).min(max_delay);
248 }
249 };
250 }
251 }
252 }
253 });
254 Ok(ret)
255 }
256
257 /// Returns whether the stream is currently healthy.
258 ///
259 /// This flag is set to `false` when the stream errors or ends, and set
260 /// back to `true` after a successful rebuild.
261 pub fn stream_healthy(&self) -> bool {
262 self.stream_healthy.load(Ordering::Relaxed)
263 }
264
265 /// Get the current `ServerConfig`.
266 ///
267 /// This is a cheap, lock-free read that clones the internal `Arc`.
268 /// Callers can hold onto the returned `Arc<ServerConfig>` as long as
269 /// needed; updates will affect future calls, not the already-held value.
270 pub fn get_config(&self) -> Arc<ServerConfig> {
271 self.inner.load_full()
272 }
273}
274
#[cfg(test)]
mod tests {
    use std::{
        collections::VecDeque,
        sync::{
            Arc,
            atomic::{AtomicUsize, Ordering},
        },
    };

    use rustls::ServerConfig;
    use thiserror::Error;
    use tokio::sync::mpsc;
    use tokio_stream::wrappers::ReceiverStream;

    use crate::{ServerConfigProvider, ServerConfigStreamBuilder, ServerConfigStreamError};

    /// Item type carried by the mock config channels.
    type ConfigResult = Result<Arc<ServerConfig>, ServerConfigStreamError>;

    /// Minimal error whose `Display` output is the wrapped message.
    #[derive(Error, Debug)]
    #[error("{0}")]
    struct MockError(&'static str);

    /// Certificate resolver that never resolves anything; it exists only so
    /// a `ServerConfig` can be constructed.
    #[derive(Debug)]
    struct NoopResolver;

    impl rustls::server::ResolvesServerCert for NoopResolver {
        fn resolve(
            &self,
            _client_hello: rustls::server::ClientHello<'_>,
        ) -> Option<Arc<rustls::sign::CertifiedKey>> {
            None
        }
    }

    /// Builds a fresh, distinct `ServerConfig` so tests can compare configs
    /// by pointer identity.
    fn empty_server_config() -> Arc<ServerConfig> {
        Arc::new(
            ServerConfig::builder()
                .with_no_client_auth()
                .with_cert_resolver(Arc::new(NoopResolver)),
        )
    }

    /// Convenience constructor for a stream-level error item.
    fn fake_stream_error() -> ServerConfigStreamError {
        ServerConfigStreamError::StreamError(MockError("fake error").into())
    }

    /// Builder that hands out pre-made channel receivers, one per `build()`
    /// call, and fails once they are exhausted.
    #[derive(Debug)]
    struct MockServerConfigStreamBuilder {
        /// Receivers still to be handed out, in order.
        streams: VecDeque<mpsc::Receiver<ConfigResult>>,
        /// Number of `build()` calls so far (successful or not).
        builds: Arc<AtomicUsize>,
    }

    impl MockServerConfigStreamBuilder {
        fn new(streams: Vec<mpsc::Receiver<ConfigResult>>) -> Self {
            Self {
                streams: streams.into(),
                builds: Arc::new(AtomicUsize::new(0)),
            }
        }
    }

    impl ServerConfigStreamBuilder for MockServerConfigStreamBuilder {
        type ConfigStream = ReceiverStream<ConfigResult>;

        async fn build(&mut self) -> Result<Self::ConfigStream, ServerConfigStreamError> {
            self.builds.fetch_add(1, Ordering::SeqCst);
            let rx = self.streams.pop_front().ok_or_else(|| {
                ServerConfigStreamError::StreamBuilderError(MockError("mock stream error").into())
            })?;
            Ok(ReceiverStream::new(rx))
        }
    }

    #[tokio::test]
    async fn start_fails_given_initial_stream_build_failure() {
        let builder = MockServerConfigStreamBuilder::new(vec![]);

        let res = ServerConfigProvider::start(builder).await;

        assert!(
            matches!(res, Err(ServerConfigStreamError::StreamBuilderError(_))),
            "expected ServerConfigStreamError::StreamBuilderError"
        );
    }

    #[tokio::test]
    async fn start_fails_when_stream_is_empty() {
        let (tx, rx) = mpsc::channel(1);

        // drop tx so stream returns Poll::Ready(None)
        drop(tx);

        let builder = MockServerConfigStreamBuilder::new(vec![rx]);

        let res = ServerConfigProvider::start(builder).await;

        assert!(
            matches!(res, Err(ServerConfigStreamError::EmptyStream)),
            "expected ServerConfigStreamError::EmptyStream"
        );
    }

    #[tokio::test]
    async fn start_fails_when_first_result_is_err() {
        let (tx, rx) = mpsc::channel(1);
        let builder = MockServerConfigStreamBuilder::new(vec![rx]);

        tx.send(Err(fake_stream_error())).await.unwrap();

        let res = ServerConfigProvider::start(builder).await;

        let Err(ServerConfigStreamError::StreamError(err)) = res else {
            panic!("expected ServerConfigStreamError::StreamError");
        };
        assert_eq!(err.to_string(), "fake error");
    }

    #[tokio::test]
    async fn start_and_initial_config_is_loaded() {
        let (tx, rx) = mpsc::channel(1);
        let builder = MockServerConfigStreamBuilder::new(vec![rx]);
        let expected = empty_server_config();
        tx.send(Ok(expected.clone())).await.unwrap();

        let provider = ServerConfigProvider::start(builder).await.unwrap();

        assert!(Arc::ptr_eq(&provider.get_config(), &expected));
        assert!(provider.stream_healthy());
    }

    #[tokio::test]
    async fn single_stream_config_hot_swap() {
        let (tx, rx) = mpsc::channel(1);
        let builder = MockServerConfigStreamBuilder::new(vec![rx]);

        let initial = empty_server_config();
        tx.send(Ok(initial.clone())).await.unwrap();
        let provider = ServerConfigProvider::start(builder).await.unwrap();
        assert!(Arc::ptr_eq(&provider.get_config(), &initial));
        assert!(provider.stream_healthy());

        for i in 0..10 {
            let expected = empty_server_config();
            tx.send(Ok(expected.clone())).await.unwrap();

            // let the background task pick up the update
            tokio::task::yield_now().await;
            assert!(
                Arc::ptr_eq(&provider.get_config(), &expected),
                "config not updated on iter {i}"
            );
            assert!(provider.stream_healthy());
        }
    }

    #[tokio::test]
    async fn stream_failure_triggers_rebuild() {
        let (tx1, rx1) = mpsc::channel(1);
        let (tx2, rx2) = mpsc::channel(1);
        let builder = MockServerConfigStreamBuilder::new(vec![rx1, rx2]);
        let builds = Arc::clone(&builder.builds);
        let initial = empty_server_config();
        tx1.send(Ok(initial.clone())).await.unwrap();
        let provider = ServerConfigProvider::start(builder).await.unwrap();
        assert!(Arc::ptr_eq(&provider.get_config(), &initial));
        assert!(provider.stream_healthy());

        // inject an error into the first stream
        tx1.send(Err(fake_stream_error())).await.unwrap();

        // polling to assert provider.stream_healthy
        // goes to false proved to be flaky due to it
        // going back to healthy too fast.

        // check that it rebuilt the stream via the builder
        tokio::task::yield_now().await;
        assert_eq!(builds.load(Ordering::SeqCst), 2);

        // push a new config and check that it's loaded
        let new = empty_server_config();
        tx2.send(Ok(new.clone())).await.unwrap();
        tokio::task::yield_now().await;

        // check that stream is healthy and new config was loaded
        assert!(provider.stream_healthy());
        assert!(Arc::ptr_eq(&provider.get_config(), &new));
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn stream_rebuild_goes_into_backoff() {
        let (tx, rx) = mpsc::channel(1);
        let builder = MockServerConfigStreamBuilder::new(vec![rx]);
        let builds = Arc::clone(&builder.builds);
        let initial = empty_server_config();
        tx.send(Ok(initial.clone())).await.unwrap();
        let provider = ServerConfigProvider::start(builder).await.unwrap();
        assert!(Arc::ptr_eq(&provider.get_config(), &initial));
        assert!(provider.stream_healthy());
        assert_eq!(builds.load(Ordering::SeqCst), 1);

        tx.send(Err(fake_stream_error())).await.unwrap();
        tokio::task::yield_now().await;

        // assert it tried to rebuild stream but is still unhealthy since
        // the MockServerConfigStreamBuilder will return an error as its
        // queue of streams is empty.
        assert_eq!(builds.load(Ordering::SeqCst), 2);
        assert!(!provider.stream_healthy());
    }
}