rustls_config_stream/
client.rs

1// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
2
3use std::{
4    sync::{
5        Arc,
6        atomic::{AtomicBool, Ordering},
7    },
8    time::Duration,
9};
10
11use arc_swap::ArcSwap;
12use rustls::{ClientConfig, client::VerifierBuilderError};
13use thiserror::Error;
14use tokio::time::sleep;
15use tokio_stream::{Stream, StreamExt};
16
17#[cfg(feature = "tracing")]
18use tracing::{debug, error, info};
19
20/// Errors that can occur while building or consuming a client-config stream.
21///
22/// These represent failures either from the user-provided stream/builder
23/// or from [`rustls`] itself.
24#[derive(Debug, Error)]
25pub enum ClientConfigStreamError {
26    /// The underlying stream produced an error.
27    ///
28    /// This is used to wrap arbitrary stream provider errors.
29    #[error("stream provider error")]
30    StreamError(Box<dyn std::error::Error + Send + Sync + 'static>),
31
32    /// The stream completed without yielding an initial [`ClientConfig`].
33    ///
34    /// [`ClientConfigProvider::start`] requires at least one item to seed
35    /// the provider; otherwise startup fails with this error.
36    #[error("empty stream")]
37    EmptyStream,
38
39    /// The builder failed to construct a stream.
40    ///
41    /// The provider will surface this when initial construction fails.
42    #[error("could not build stream")]
43    StreamBuilderError(Box<dyn std::error::Error + Send + Sync + 'static>),
44
45    /// Error originating from [`rustls`] certificate verifier construction.
46    #[error("cert verifier builder error")]
47    VerifierBuilderError(#[from] VerifierBuilderError),
48
49    /// The builder/stream did not provide a [`rustls::sign::CertifiedKey`]
50    #[error("missing client certified key")]
51    MissingCertifiedKey,
52
53    /// The builder/stream did not provide any root certificates resulting in an empty [`rustls::RootCertStore`]
54    #[error("missing root certificates")]
55    MissingRoots,
56
57    /// Wrapper for any [`rustls`] error.
58    #[error("rustls error")]
59    RustlsError(#[from] rustls::Error),
60}
61
/// A factory for producing a stream of [`rustls::ClientConfig`].
///
/// Implement this trait to define how your application sources TLS configs
/// (e.g., file watchers, secret managers, pull-from-API).
///
/// The returned stream should yield *complete* [`ClientConfig`] values. Each
/// item replaces the provider's current config atomically (via [`ArcSwap`]).
///
/// # Contract
/// - [`build()`](ClientConfigStreamBuilder::build) should return a stream that eventually yields at least one
///   [`ClientConfig`] during initial startup. If the stream ends without
///   yielding an item, startup fails with
///   [`ClientConfigStreamError::EmptyStream`]; if its first item is an error,
///   startup fails with that error.
/// - On stream failure, the provider will call [`build()`](ClientConfigStreamBuilder::build) again with backoff.
/// - Items from the stream should be independent [`Arc<ClientConfig>`] values.
///
/// # Examples
/// ```rust,ignore
/// use std::sync::Arc;
/// use rustls::ClientConfig;
/// use tokio_stream::{Stream, wrappers::ReceiverStream};
///
/// struct MyConfigProvider;
///
/// impl ClientConfigStreamBuilder for MyConfigProvider {
///     type ConfigStream = ReceiverStream<Result<Arc<ClientConfig>, ClientConfigStreamError>>;
///
///     async fn build(
///         &mut self,
///     ) -> Result<Self::ConfigStream, ClientConfigStreamError> {
///         // Construct a stream that yields ClientConfig updates.
///         // See the SPIFFE implementation in `rustls-spiffe` for a full example.
///         unimplemented!()
///     }
/// }
/// ```
pub trait ClientConfigStreamBuilder {
    /// The stream type produced by this builder.
    ///
    /// Each item is either a fresh [`ClientConfig`] or an error explaining why
    /// the update failed.
    ///
    /// The stream is moved into a task spawned by
    /// [`ClientConfigProvider::start`], which is why it must be `Send` and
    /// `'static`. It is polled there through `StreamExt::next`, which needs
    /// `Unpin`.
    type ConfigStream: Stream<Item = Result<Arc<ClientConfig>, ClientConfigStreamError>>
        + Send
        + Sync
        + Unpin
        + 'static;

    /// Asynchronously construct a new configuration stream.
    ///
    /// The provider will:
    /// - call this once during startup to obtain the initial stream,
    /// - read the *first* config to seed its state,
    /// - continue to poll the provided stream for new configs,
    /// - upon stream failure or completion, call it again with
    ///   exponential backoff until a new stream is available.
    ///
    /// The returned future must be `Send` because the provider awaits it
    /// inside a spawned task.
    ///
    /// # Errors
    /// Returns a [`ClientConfigStreamError`] when a stream cannot be
    /// constructed.
    fn build(
        &mut self,
    ) -> impl std::future::Future<Output = Result<Self::ConfigStream, ClientConfigStreamError>> + Send;
}
120
/// Holds the current [`ClientConfig`] and refreshes it from an async stream.
///
/// Internally uses [`ArcSwap<ClientConfig>`] to provide lock-free, atomic swaps
/// of the active TLS configuration. Call [`get_config`](Self::get_config) to
/// obtain an [`Arc<ClientConfig>`] to use for new client-side TLS handshakes.
///
/// Liveness of the underlying stream can be checked via
/// [`stream_healthy`](Self::stream_healthy).
///
/// # Concurrency
/// Reads via [`get_config`](ClientConfigProvider::get_config) are lock-free and do not block updates.
/// Updates occur on a background task that listens to the user-provided stream.
///
/// # Backoff & Recovery
/// When the stream ends or errors, the provider:
/// - Marks itself unhealthy,
/// - Rebuilds the stream via the builder,
/// - Retries with exponential backoff starting at 10ms and capping at 10s,
/// - Resets backoff after a successful re-establishment.
///
/// The last successfully received config keeps being served while the
/// stream is unhealthy.
pub struct ClientConfigProvider {
    /// The current, atomically-swappable client configuration.
    ///
    /// Seeded with the first item of the stream and overwritten by every
    /// later `Ok` item.
    inner: ArcSwap<ClientConfig>,

    /// Health flag for the underlying stream (true = healthy).
    ///
    /// Written only by the background task; read with relaxed ordering
    /// because it is an advisory signal and guards no other data.
    stream_healthy: AtomicBool,
}
147
148impl ClientConfigProvider {
149    /// Initializes the provider and spawn the background refresh task.
150    ///
151    /// This awaits the first item from the builder's stream to seed the
152    /// internal configuration. It then spawns a task that continuously reads
153    /// subsequent updates, atomically swapping them into place.
154    ///
155    /// On stream failure or completion, the task attempts to rebuild the
156    /// stream using exponential backoff (initial 10ms, max 10s, doubling).
157    ///
158    /// Returns an [`Arc<ClientConfigProvider>`]
159    ///
160    /// # Errors
161    /// - [`ClientConfigStreamError::EmptyStream`]: the initial stream yielded no item.
162    /// - [`ClientConfigStreamError::StreamBuilderError`]: building the stream failed.
163    /// - [`ClientConfigStreamError`] variants wrapping errors from your builder or `rustls`.
164    pub async fn start<B>(mut builder: B) -> Result<Arc<Self>, ClientConfigStreamError>
165    where
166        B: ClientConfigStreamBuilder + Send + 'static,
167    {
168        let mut stream = builder.build().await?;
169        let initial = stream
170            .next()
171            .await
172            .ok_or(ClientConfigStreamError::EmptyStream)??;
173        let this = Arc::new(Self {
174            inner: ArcSwap::from(initial),
175            stream_healthy: AtomicBool::new(true),
176        });
177        let ret = this.clone();
178
179        tokio::spawn(async move {
180            let initial_delay = Duration::from_millis(10);
181            let mut delay = initial_delay;
182            let max_delay = Duration::from_secs(10);
183            loop {
184                match stream.next().await {
185                    Some(Ok(client_config)) => {
186                        this.inner.store(client_config);
187
188                        #[cfg(feature = "tracing")]
189                        debug!("stored updated client config from stream");
190                    }
191                    Some(Err(_)) | None => {
192                        this.stream_healthy.store(false, Ordering::Relaxed);
193
194                        #[cfg(feature = "tracing")]
195                        error!("config stream returned error or none, trying to build new stream");
196
197                        match builder.build().await {
198                            Ok(s) => {
199                                this.stream_healthy.store(true, Ordering::Relaxed);
200                                delay = initial_delay;
201                                stream = s;
202
203                                #[cfg(feature = "tracing")]
204                                info!("reestablished client config stream");
205                            }
206                            Err(err) => {
207                                #[cfg(feature = "tracing")]
208                                error!(retry_in_ms = delay.as_millis(), error = %err, "failed to reestablish client config stream");
209
210                                sleep(delay).await;
211                                delay = (delay * 2).min(max_delay);
212                            }
213                        };
214                    }
215                }
216            }
217        });
218        Ok(ret)
219    }
220
221    /// Returns whether the stream is currently healthy.
222    ///
223    /// This flag is set to `false` when the stream errors or ends, and set
224    /// back to `true` after a successful rebuild.
225    pub fn stream_healthy(&self) -> bool {
226        self.stream_healthy.load(Ordering::Relaxed)
227    }
228
229    /// Get the current [`ClientConfig`].
230    ///
231    /// This is a cheap, lock-free read that loads the internal [`ArcSwap<ClientConfig>`] into an [`Arc<ClientConfig>`]
232    /// Callers can hold onto the returned [`Arc<ClientConfig>`] as long as
233    /// needed; updates will affect future calls, not the already-held value.
234    pub fn get_config(&self) -> Arc<ClientConfig> {
235        self.inner.load_full()
236    }
237}
238
#[cfg(test)]
mod tests {
    use std::{
        collections::VecDeque,
        sync::{
            Arc,
            atomic::{AtomicUsize, Ordering},
        },
    };

    use rustls::{ClientConfig, RootCertStore};
    use thiserror::Error;
    use tokio::sync::{Mutex, mpsc};
    use tokio_stream::wrappers::ReceiverStream;

    use crate::{ClientConfigProvider, ClientConfigStreamBuilder, ClientConfigStreamError};

    /// Item type carried by the mock config channels.
    type ConfigResult = Result<Arc<ClientConfig>, ClientConfigStreamError>;

    /// Minimal error whose `Display` output is exactly the wrapped message.
    #[derive(Error, Debug)]
    struct MockError(&'static str);
    impl std::fmt::Display for MockError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.write_str(self.0)
        }
    }

    /// Builds a fresh config with no roots and no client auth.
    ///
    /// Every call returns a distinct `Arc`, so tests can use `Arc::ptr_eq`
    /// to tell which config the provider currently holds.
    fn empty_client_config() -> Arc<ClientConfig> {
        Arc::from(
            ClientConfig::builder()
                .with_root_certificates(RootCertStore::empty())
                .with_no_client_auth(),
        )
    }

    /// Builder that hands out pre-made channel receivers, one per `build`
    /// call, and fails with `StreamBuilderError` once they run out.
    #[derive(Debug)]
    struct MockClientConfigStreamBuilder {
        /// Receivers still to be handed out, in order.
        streams: Mutex<VecDeque<mpsc::Receiver<ConfigResult>>>,
        /// Number of `build` calls so far, successful or not.
        builds: Arc<AtomicUsize>,
    }

    impl MockClientConfigStreamBuilder {
        fn new(streams: Vec<mpsc::Receiver<ConfigResult>>) -> Self {
            let builds = Arc::new(AtomicUsize::new(0));
            let streams = Mutex::new(VecDeque::from(streams));
            Self { streams, builds }
        }
    }

    impl ClientConfigStreamBuilder for MockClientConfigStreamBuilder {
        type ConfigStream = ReceiverStream<ConfigResult>;

        async fn build(&mut self) -> Result<Self::ConfigStream, ClientConfigStreamError> {
            self.builds.fetch_add(1, Ordering::SeqCst);
            let rx = self.streams.lock().await.pop_front().ok_or_else(|| {
                ClientConfigStreamError::StreamBuilderError(MockError("mock stream error").into())
            })?;
            Ok(ReceiverStream::new(rx))
        }
    }

    #[tokio::test]
    async fn start_fails_given_initial_stream_build_failure() {
        let builder = MockClientConfigStreamBuilder::new(vec![]);

        let res = ClientConfigProvider::start(builder).await;
        match res {
            Err(ClientConfigStreamError::StreamBuilderError(_)) => { /* test pass */ }
            Err(other) => {
                panic!("expected ClientConfigStreamError::StreamBuilderError, got {other:?}")
            }
            Ok(_) => {
                panic!("expected ClientConfigStreamError::StreamBuilderError, but start succeeded")
            }
        }
    }

    #[tokio::test]
    async fn start_fails_when_stream_is_empty() {
        let (tx, rx) = mpsc::channel(1);

        // drop tx so stream returns Poll::Ready(None)
        std::mem::drop(tx);

        let builder = MockClientConfigStreamBuilder::new(vec![rx]);

        let res = ClientConfigProvider::start(builder).await;
        match res {
            Err(ClientConfigStreamError::EmptyStream) => { /* test pass */ }
            Err(other) => panic!("expected ClientConfigStreamError::EmptyStream, got {other:?}"),
            Ok(_) => panic!("expected ClientConfigStreamError::EmptyStream, but start succeeded"),
        }
    }

    #[tokio::test]
    async fn start_fails_when_first_result_is_err() {
        let (tx, rx) = mpsc::channel(1);
        let builder = MockClientConfigStreamBuilder::new(vec![rx]);

        tx.send(Err(ClientConfigStreamError::StreamError(
            MockError("fake error").into(),
        )))
        .await
        .unwrap();

        let res = ClientConfigProvider::start(builder).await;
        match res {
            Err(ClientConfigStreamError::StreamError(err)) => {
                assert_eq!(err.to_string(), "fake error");
            }
            Err(other) => panic!("expected ClientConfigStreamError::StreamError, got {other:?}"),
            Ok(_) => panic!("expected ClientConfigStreamError::StreamError, but start succeeded"),
        }
    }

    #[tokio::test]
    async fn start_and_initial_config_is_loaded() {
        let (tx, rx) = mpsc::channel(1);
        let builder = MockClientConfigStreamBuilder::new(vec![rx]);
        let expected = empty_client_config();
        tx.send(Ok(expected.clone())).await.unwrap();
        let provider = ClientConfigProvider::start(builder).await.unwrap();

        let got = provider.get_config();

        assert!(Arc::ptr_eq(&got, &expected));
        assert!(provider.stream_healthy());
    }

    #[tokio::test]
    async fn single_stream_config_hot_swap() {
        let (tx, rx) = mpsc::channel(1);
        let builder = MockClientConfigStreamBuilder::new(vec![rx]);

        let initial = empty_client_config();
        tx.send(Ok(initial.clone())).await.unwrap();
        let provider = ClientConfigProvider::start(builder).await.unwrap();
        let got = provider.get_config();
        assert!(Arc::ptr_eq(&got, &initial));
        assert!(provider.stream_healthy());

        for i in 0..10 {
            let expected = empty_client_config();
            tx.send(Ok(expected.clone())).await.unwrap();

            // Let the background task run and store the update.
            tokio::task::yield_now().await;
            let got = provider.get_config();
            assert!(
                Arc::ptr_eq(&got, &expected),
                "config not updated on iter {i}"
            );
            assert!(provider.stream_healthy());
        }
    }

    #[tokio::test]
    async fn stream_failure_triggers_rebuild() {
        let (tx1, rx1) = mpsc::channel(1);
        let (tx2, rx2) = mpsc::channel(1);
        let builder = MockClientConfigStreamBuilder::new(vec![rx1, rx2]);
        let builds = Arc::clone(&builder.builds);
        let initial = empty_client_config();
        tx1.send(Ok(initial.clone())).await.unwrap();
        let provider = ClientConfigProvider::start(builder).await.unwrap();
        assert!(Arc::ptr_eq(&provider.get_config(), &initial));
        assert!(provider.stream_healthy());

        tx1.send(Err(ClientConfigStreamError::StreamError(
            MockError("fake error").into(),
        )))
        .await
        .unwrap();

        // polling to assert provider.stream_healthy
        // goes to false proved to be flaky due to it
        // going back to healthy too fast.

        // check that it rebuilt the stream via the provider
        tokio::task::yield_now().await;
        assert_eq!(builds.load(Ordering::SeqCst), 2);

        // push a new config and check that it's loaded
        let new = empty_client_config();
        tx2.send(Ok(new.clone())).await.unwrap();
        tokio::task::yield_now().await;

        // check that stream is healthy and new config was loaded
        assert!(provider.stream_healthy());
        assert!(Arc::ptr_eq(&provider.get_config(), &new));
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn stream_rebuild_goes_into_backoff() {
        let (tx, rx) = mpsc::channel(1);
        let builder = MockClientConfigStreamBuilder::new(vec![rx]);
        let builds = Arc::clone(&builder.builds);
        let initial = empty_client_config();
        tx.send(Ok(initial.clone())).await.unwrap();
        let provider = ClientConfigProvider::start(builder).await.unwrap();
        assert!(Arc::ptr_eq(&provider.get_config(), &initial));
        assert!(provider.stream_healthy());
        assert_eq!(builds.load(Ordering::SeqCst), 1);

        tx.send(Err(ClientConfigStreamError::StreamError(
            MockError("fake error").into(),
        )))
        .await
        .unwrap();
        tokio::task::yield_now().await;
        // assert it tried to rebuild stream but is still unhealthy since
        // the MockClientConfigStreamBuilder will return an error as the
        // streams queue is empty.
        assert_eq!(builds.load(Ordering::SeqCst), 2);
        assert!(!provider.stream_healthy());
    }
}
447}