rustls_config_stream/server.rs
1// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
2
3use std::{
4 sync::{
5 Arc,
6 atomic::{AtomicBool, Ordering},
7 },
8 time::Duration,
9};
10
11use arc_swap::ArcSwap;
12use rustls::{ServerConfig, server::VerifierBuilderError};
13use thiserror::Error;
14use tokio::time::sleep;
15use tokio_stream::{Stream, StreamExt};
16
17#[cfg(feature = "tracing")]
18use tracing::{debug, error, info};
19
20/// Errors that can occur while building or consuming a server-config stream.
21///
22/// These represent failures either from the user-provided stream/builder
23/// or from [`rustls`] itself.
24#[derive(Debug, Error)]
25pub enum ServerConfigStreamError {
26 /// The underlying stream produced an error.
27 ///
28 /// This is used to wrap arbitrary stream provider errors.
29 #[error("stream provider error")]
30 StreamError(Box<dyn std::error::Error + Send + Sync + 'static>),
31
32 /// The stream completed without yielding an initial [`ServerConfig`].
33 ///
34 /// [`ServerConfigProvider::start`] requires at least one item to seed
35 /// the provider; otherwise startup fails with this error.
36 #[error("empty stream")]
37 EmptyStream,
38
39 /// The builder failed to construct a stream.
40 ///
41 /// The provider will surface this when initial construction fails.
42 #[error("could not build stream")]
43 StreamBuilderError(Box<dyn std::error::Error + Send + Sync + 'static>),
44
45 /// Error originating from [`rustls`] certificate verifier construction.
46 #[error("cert verifier builder error")]
47 VerifierBuilderError(#[from] VerifierBuilderError),
48
49 /// The builder/stream did not provide a [`rustls::sign::CertifiedKey`]
50 #[error("missing server certified key")]
51 MissingCertifiedKey,
52
53 /// The builder/stream did not provide any root certificates resulting in an empty [`rustls::RootCertStore`]
54 #[error("missing root certificates")]
55 MissingRoots,
56
57 /// Wrapper for any [`rustls::Error`] error.
58 #[error("rustls error")]
59 RustlsError(#[from] rustls::Error),
60}
61
62/// A factory for producing a stream of [`rustls::ServerConfig`].
63///
64/// Implement this trait to define how your application sources TLS configs
65/// (e.g., file watchers, secret managers, pull-from-API).
66///
67/// The returned stream should yield *complete* [`ServerConfig`] values. Each
68/// item replaces the provider's current config atomically (via [`ArcSwap`]).
69///
70/// # Contract
71/// - [`build()`](ServerConfigStreamBuilder::build) should return a stream that eventually yields at least one
72/// [`ServerConfig`] during initial startup. If it doesn't, startup will fail
73/// with [`ServerConfigStreamError::EmptyStream`].
74/// - On stream failure, the provider will call [`build()`](ServerConfigStreamBuilder::build) again with backoff.
75/// - Items from the stream should be independent [`Arc<ServerConfig>`] values.
76///
77/// # Examples
78/// ```no_run
79/// use std::sync::Arc;
80/// use rustls::ServerConfig;
81/// use tokio_stream::{Stream, wrappers::ReceiverStream};
82///
83/// struct MyConfigProvider;
84///
85/// impl ServerConfigStreamBuilder for MyConfigProvider {
86/// type ConfigStream = ReceiverStream<Result<Arc<ServerConfig>, ServerConfigStreamError>>;
87///
88/// async fn build(
89/// &mut self,
90/// ) -> Result<Self::ConfigStream, ServerConfigStreamError> {
91/// // Construct a stream that yields ServerConfig updates.
92/// // See the SPIFFE implementation in [`rustls-spiffe`] for a full example.
93/// unimplemented!()
94/// }
95/// }
96/// ```
97pub trait ServerConfigStreamBuilder {
98 /// The stream type produced by this builder.
99 ///
100 /// Each item is either a fresh [`ServerConfig`] or an error explaining why
101 /// the update failed.
102 type ConfigStream: Stream<Item = Result<Arc<ServerConfig>, ServerConfigStreamError>>
103 + Send
104 + Sync
105 + Unpin
106 + 'static;
107
108 /// Asynchronously construct a new configuration stream.
109 ///
110 /// The provider will:
111 /// - call this once during startup to obtain the initial stream,
112 /// - read the *first* config to seed its state,
113 /// - continue to poll the provided stream for new configs
114 /// - upon stream failure or completion, call it again with
115 /// exponential backoff until a new stream is available.
116 fn build(
117 &mut self,
118 ) -> impl std::future::Future<Output = Result<Self::ConfigStream, ServerConfigStreamError>> + Send;
119}
120
121/// Holds the current [`ServerConfig`] and refreshes it from an async stream.
122///
123/// Internally uses [`ArcSwap<ServerConfig>`] to provide lock-free, atomic swaps
124/// of the active TLS configuration. Call [`get_config`](Self::get_config) to
125/// obtain an [`Arc<ServerConfig>`] for acceptors or handshakes.
126///
127/// Liveness of the underlying stream can be checked via
128/// [`stream_healthy`](Self::stream_healthy).
129///
130/// # Concurrency
131/// Reads [`get_config()`](Self::get_config) are wait-free and do not block updates.
132/// Updates occur on a background task that listens to the user-provided stream.
133///
134/// # Backoff & Recovery
135/// When the stream ends or errors, the provider:
136/// - Marks itself unhealthy,
137/// - Rebuilds the stream via the builder,
138/// - Retries with exponential backoff starting at 10ms and capping at 10s,
139/// - Resets backoff after a successful re-establishment.
140///
141/// # Examples
142/// ```no_run
143/// let config_stream_builder =
144/// SpiffeServerConfigStream::builder(vec!["example.org".try_into().unwrap()]);
145/// let config_provider = ServerConfigProvider::start(config_stream_builder)
146/// .await
147/// .unwrap();
148/// let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
149/// .await
150/// .unwrap();
151///
152/// loop {
153/// let (stream, _) = listener.accept().await.unwrap();
154///
155/// let acceptor =
156/// tokio_rustls::LazyConfigAcceptor::new(rustls::server::Acceptor::default(), stream);
157/// tokio::pin!(acceptor);
158///
159/// let config_provider = config_provider.clone();
160/// match acceptor.as_mut().await {
161/// Ok(start) => {
162/// tokio::spawn(async move {
163/// if !config_provider.stream_healthy() {
164/// warn!(
165/// "config provider does not have healthy stream; TLS config may be out of date"
166/// );
167/// }
168/// let config = config_provider.get_config();
169/// match start.into_stream(config).await {
170/// Ok(stream) => { /* serve some app (e.g. hyper, tower, axum) */ },
171/// Err(err) => { /* handle error */ }
172/// }
173/// })
174/// }
175/// Err(_) => { /* handle error */ }
176/// }
177/// }
178/// ```
179pub struct ServerConfigProvider {
180 /// The current, atomically-swappable server configuration.
181 inner: ArcSwap<ServerConfig>,
182
183 /// Health flag for the underlying stream (true = healthy).
184 stream_healthy: AtomicBool,
185}
186
187impl ServerConfigProvider {
188 /// Initializes the provider and spawn the background refresh task.
189 ///
190 /// This awaits the first item from the builder's stream to seed the
191 /// internal configuration. It then spawns a task that continuously reads
192 /// subsequent updates, atomically swapping them into place.
193 ///
194 /// On stream failure or completion, the task attempts to rebuild the
195 /// stream using exponential backoff (initial 10ms, max 10s, doubling).
196 ///
197 /// Returns an [`Arc<ServerConfigProvider>`]
198 ///
199 /// # Errors
200 /// - [`ServerConfigStreamError::EmptyStream`]: the initial stream yielded no item.
201 /// - [`ServerConfigStreamError::StreamBuilderError`]: building the stream failed.
202 /// - [`ServerConfigStreamError`] variants wrapping errors from your builder or `rustls`.
203 pub async fn start<B>(mut builder: B) -> Result<Arc<Self>, ServerConfigStreamError>
204 where
205 B: ServerConfigStreamBuilder + Send + 'static,
206 {
207 let mut stream = builder.build().await?;
208 let initial = stream
209 .next()
210 .await
211 .ok_or(ServerConfigStreamError::EmptyStream)??;
212 let this = Arc::new(Self {
213 inner: ArcSwap::from(initial),
214 stream_healthy: AtomicBool::new(true),
215 });
216 let ret = this.clone();
217
218 tokio::spawn(async move {
219 let initial_delay = Duration::from_millis(10);
220 let mut delay = initial_delay;
221 let max_delay = Duration::from_secs(10);
222 loop {
223 match stream.next().await {
224 Some(Ok(server_config)) => {
225 this.inner.store(server_config);
226
227 #[cfg(feature = "tracing")]
228 debug!(name: "server_config_provider", "stored updated server config from stream");
229 }
230 Some(Err(_)) | None => {
231 this.stream_healthy.store(false, Ordering::Relaxed);
232
233 #[cfg(feature = "tracing")]
234 error!("config stream returned error or none, trying to build new stream");
235
236 match builder.build().await {
237 Ok(s) => {
238 this.stream_healthy.store(true, Ordering::Relaxed);
239 delay = initial_delay;
240 stream = s;
241
242 #[cfg(feature = "tracing")]
243 info!(name: "server_config_provider", "reestablished server config stream");
244 }
245 Err(err) => {
246 #[cfg(feature = "tracing")]
247 error!(name: "server_config_provider", retry_in_ms = delay.as_millis(), error = %err, "failed to reestablish server config stream");
248
249 sleep(delay).await;
250 delay = (delay * 2).min(max_delay);
251 }
252 };
253 }
254 }
255 }
256 });
257 Ok(ret)
258 }
259
260 /// Returns whether the stream is currently healthy.
261 ///
262 /// This flag is set to `false` when the stream errors or ends, and set
263 /// back to `true` after a successful rebuild.
264 pub fn stream_healthy(&self) -> bool {
265 self.stream_healthy.load(Ordering::Relaxed)
266 }
267
268 /// Get the current [`ServerConfig`].
269 ///
270 /// This is a cheap, lock-free read that loads the internal [`ArcSwap<ServerConfig>`] into an [`Arc<ServerConfig>`]
271 /// Callers can hold onto the returned [`Arc<ServerConfig>`] as long as
272 /// needed; updates will affect future calls, not the already-held value.
273 pub fn get_config(&self) -> Arc<ServerConfig> {
274 self.inner.load_full()
275 }
276}
277
#[cfg(test)]
mod tests {
    use std::{
        collections::VecDeque,
        sync::{
            Arc,
            atomic::{AtomicUsize, Ordering},
        },
    };

    use rustls::ServerConfig;
    use thiserror::Error;
    use tokio::sync::mpsc;
    use tokio_stream::wrappers::ReceiverStream;

    use crate::{ServerConfigProvider, ServerConfigStreamBuilder, ServerConfigStreamError};

    /// Item type carried by every mock config stream.
    type ConfigResult = Result<Arc<ServerConfig>, ServerConfigStreamError>;

    /// Minimal error type used to fabricate stream/builder failures.
    #[derive(Error, Debug)]
    #[error("{0}")]
    struct MockError(&'static str);

    /// Builds a stream item carrying a `StreamError` with the given message.
    fn stream_error(msg: &'static str) -> ConfigResult {
        Err(ServerConfigStreamError::StreamError(MockError(msg).into()))
    }

    /// Certificate resolver that never returns a key; enough to build a config.
    #[derive(Debug)]
    struct NoopResolver;
    impl rustls::server::ResolvesServerCert for NoopResolver {
        fn resolve(
            &self,
            _client_hello: rustls::server::ClientHello<'_>,
        ) -> Option<Arc<rustls::sign::CertifiedKey>> {
            None
        }
    }

    /// Returns a fresh, distinct config so tests can compare by pointer.
    fn empty_server_config() -> Arc<ServerConfig> {
        Arc::new(
            ServerConfig::builder()
                .with_no_client_auth()
                .with_cert_resolver(Arc::new(NoopResolver)),
        )
    }

    /// Hands out the given receivers in order, one per `build` call, and
    /// fails with `StreamBuilderError` once they run out. Counts every call.
    #[derive(Debug)]
    struct MockServerConfigStreamBuilder {
        streams: VecDeque<mpsc::Receiver<ConfigResult>>,
        builds: Arc<AtomicUsize>,
    }

    impl MockServerConfigStreamBuilder {
        fn new(streams: Vec<mpsc::Receiver<ConfigResult>>) -> Self {
            Self {
                streams: VecDeque::from(streams),
                builds: Arc::new(AtomicUsize::new(0)),
            }
        }
    }

    impl ServerConfigStreamBuilder for MockServerConfigStreamBuilder {
        type ConfigStream = ReceiverStream<ConfigResult>;

        async fn build(&mut self) -> Result<Self::ConfigStream, ServerConfigStreamError> {
            self.builds.fetch_add(1, Ordering::SeqCst);
            let rx = self.streams.pop_front().ok_or_else(|| {
                ServerConfigStreamError::StreamBuilderError(MockError("mock stream error").into())
            })?;
            Ok(ReceiverStream::new(rx))
        }
    }

    #[tokio::test]
    async fn start_fails_given_initial_stream_build_failure() {
        let builder = MockServerConfigStreamBuilder::new(vec![]);

        match ServerConfigProvider::start(builder).await {
            Err(ServerConfigStreamError::StreamBuilderError(_)) => { /* test pass */ }
            Err(other) => {
                panic!("expected ServerConfigStreamError::StreamBuilderError, got {other:?}")
            }
            Ok(_) => panic!("expected ServerConfigStreamError::StreamBuilderError, got Ok"),
        }
    }

    #[tokio::test]
    async fn start_fails_when_stream_is_empty() {
        let (tx, rx) = mpsc::channel(1);

        // Dropping the only sender closes the channel, so the stream yields
        // `Poll::Ready(None)` straight away.
        drop(tx);

        let builder = MockServerConfigStreamBuilder::new(vec![rx]);

        match ServerConfigProvider::start(builder).await {
            Err(ServerConfigStreamError::EmptyStream) => { /* test pass */ }
            Err(other) => panic!("expected ServerConfigStreamError::EmptyStream, got {other:?}"),
            Ok(_) => panic!("expected ServerConfigStreamError::EmptyStream, got Ok"),
        }
    }

    #[tokio::test]
    async fn start_fails_when_first_result_is_err() {
        let (tx, rx) = mpsc::channel(1);
        let builder = MockServerConfigStreamBuilder::new(vec![rx]);

        tx.send(stream_error("fake error")).await.unwrap();

        match ServerConfigProvider::start(builder).await {
            Err(ServerConfigStreamError::StreamError(err)) => {
                assert_eq!(err.to_string(), "fake error");
            }
            Err(other) => panic!("expected ServerConfigStreamError::StreamError, got {other:?}"),
            Ok(_) => panic!("expected ServerConfigStreamError::StreamError, got Ok"),
        }
    }

    #[tokio::test]
    async fn start_and_initial_config_is_loaded() {
        let (tx, rx) = mpsc::channel(1);
        let builder = MockServerConfigStreamBuilder::new(vec![rx]);
        let expected = empty_server_config();
        tx.send(Ok(Arc::clone(&expected))).await.unwrap();
        let provider = ServerConfigProvider::start(builder).await.unwrap();

        assert!(Arc::ptr_eq(&provider.get_config(), &expected));
        assert!(provider.stream_healthy());
    }

    #[tokio::test]
    async fn single_stream_config_hot_swap() {
        let (tx, rx) = mpsc::channel(1);
        let builder = MockServerConfigStreamBuilder::new(vec![rx]);

        let initial = empty_server_config();
        tx.send(Ok(Arc::clone(&initial))).await.unwrap();
        let provider = ServerConfigProvider::start(builder).await.unwrap();
        assert!(Arc::ptr_eq(&provider.get_config(), &initial));
        assert!(provider.stream_healthy());

        for i in 0..10 {
            let expected = empty_server_config();
            tx.send(Ok(Arc::clone(&expected))).await.unwrap();

            // Let the background task consume the update.
            tokio::task::yield_now().await;
            assert!(
                Arc::ptr_eq(&provider.get_config(), &expected),
                "config not updated on iter {i}"
            );
            assert!(provider.stream_healthy());
        }
    }

    #[tokio::test]
    async fn stream_failure_triggers_rebuild() {
        let (tx1, rx1) = mpsc::channel(1);
        let (tx2, rx2) = mpsc::channel(1);
        let builder = MockServerConfigStreamBuilder::new(vec![rx1, rx2]);
        let builds = Arc::clone(&builder.builds);
        let initial = empty_server_config();
        tx1.send(Ok(Arc::clone(&initial))).await.unwrap();
        let provider = ServerConfigProvider::start(builder).await.unwrap();
        assert!(Arc::ptr_eq(&provider.get_config(), &initial));
        assert!(provider.stream_healthy());

        tx1.send(stream_error("fake error")).await.unwrap();

        // Polling for `stream_healthy()` to become false proved flaky, because
        // it goes back to healthy too quickly after the rebuild.

        // Check that the provider rebuilt the stream via the builder.
        tokio::task::yield_now().await;
        assert_eq!(builds.load(Ordering::SeqCst), 2);

        // Push a new config on the rebuilt stream and check that it's loaded.
        let new = empty_server_config();
        tx2.send(Ok(Arc::clone(&new))).await.unwrap();
        tokio::task::yield_now().await;

        assert!(provider.stream_healthy());
        assert!(Arc::ptr_eq(&provider.get_config(), &new));
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn stream_rebuild_goes_into_backoff() {
        let (tx, rx) = mpsc::channel(1);
        let builder = MockServerConfigStreamBuilder::new(vec![rx]);
        let builds = Arc::clone(&builder.builds);
        let initial = empty_server_config();
        tx.send(Ok(Arc::clone(&initial))).await.unwrap();
        let provider = ServerConfigProvider::start(builder).await.unwrap();
        assert!(Arc::ptr_eq(&provider.get_config(), &initial));
        assert!(provider.stream_healthy());
        assert_eq!(builds.load(Ordering::SeqCst), 1);

        tx.send(stream_error("fake error")).await.unwrap();
        tokio::task::yield_now().await;

        // The provider tried to rebuild the stream, but it stays unhealthy:
        // the mock builder has no receivers left and returns an error.
        assert_eq!(builds.load(Ordering::SeqCst), 2);
        assert!(!provider.stream_healthy());
    }
}