// rustls_config_stream/client.rs
use std::{
    sync::{
        Arc, Weak,
        atomic::{AtomicBool, Ordering},
    },
    time::Duration,
};

use arc_swap::ArcSwap;
use rustls::{ClientConfig, client::VerifierBuilderError};
use thiserror::Error;
use tokio::time::sleep;
use tokio_stream::{Stream, StreamExt};

#[cfg(feature = "tracing")]
use tracing::{debug, error, info};

20#[derive(Debug, Error)]
25pub enum ClientConfigStreamError {
26 #[error("stream provider error")]
30 StreamError(Box<dyn std::error::Error + Send + Sync + 'static>),
31
32 #[error("empty stream")]
37 EmptyStream,
38
39 #[error("could not build stream")]
43 StreamBuilderError(Box<dyn std::error::Error + Send + Sync + 'static>),
44
45 #[error("cert verifier builder error")]
47 VerifierBuilderError(#[from] VerifierBuilderError),
48
49 #[error("missing client certified key")]
51 MissingCertifiedKey,
52
53 #[error("missing root certificates")]
55 MissingRoots,
56
57 #[error("rustls error")]
59 RustlsError(#[from] rustls::Error),
60}
61
62pub trait ClientConfigStreamBuilder {
98 type ConfigStream: Stream<Item = Result<Arc<ClientConfig>, ClientConfigStreamError>>
103 + Send
104 + Sync
105 + Unpin
106 + 'static;
107
108 fn build(
117 &mut self,
118 ) -> impl std::future::Future<Output = Result<Self::ConfigStream, ClientConfigStreamError>> + Send;
119}
120
121pub struct ClientConfigProvider {
141 inner: ArcSwap<ClientConfig>,
143
144 stream_healthy: AtomicBool,
146}
147
148impl ClientConfigProvider {
149 pub async fn start<B>(mut builder: B) -> Result<Arc<Self>, ClientConfigStreamError>
165 where
166 B: ClientConfigStreamBuilder + Send + 'static,
167 {
168 let mut stream = builder.build().await?;
169 let initial = stream
170 .next()
171 .await
172 .ok_or(ClientConfigStreamError::EmptyStream)??;
173 let this = Arc::new(Self {
174 inner: ArcSwap::from(initial),
175 stream_healthy: AtomicBool::new(true),
176 });
177 let ret = this.clone();
178
179 tokio::spawn(async move {
180 let initial_delay = Duration::from_millis(10);
181 let mut delay = initial_delay;
182 let max_delay = Duration::from_secs(10);
183 loop {
184 match stream.next().await {
185 Some(Ok(client_config)) => {
186 this.inner.store(client_config);
187
188 #[cfg(feature = "tracing")]
189 debug!("stored updated client config from stream");
190 }
191 Some(Err(_)) | None => {
192 this.stream_healthy.store(false, Ordering::Relaxed);
193
194 #[cfg(feature = "tracing")]
195 error!("config stream returned error or none, trying to build new stream");
196
197 match builder.build().await {
198 Ok(s) => {
199 this.stream_healthy.store(true, Ordering::Relaxed);
200 delay = initial_delay;
201 stream = s;
202
203 #[cfg(feature = "tracing")]
204 info!("reestablished client config stream");
205 }
206 Err(err) => {
207 #[cfg(feature = "tracing")]
208 error!(retry_in_ms = delay.as_millis(), error = %err, "failed to reestablish client config stream");
209
210 sleep(delay).await;
211 delay = (delay * 2).min(max_delay);
212 }
213 };
214 }
215 }
216 }
217 });
218 Ok(ret)
219 }
220
221 pub fn stream_healthy(&self) -> bool {
226 self.stream_healthy.load(Ordering::Relaxed)
227 }
228
229 pub fn get_config(&self) -> Arc<ClientConfig> {
235 self.inner.load_full()
236 }
237}
238
#[cfg(test)]
mod tests {
    use std::{
        collections::VecDeque,
        sync::{
            Arc,
            atomic::{AtomicUsize, Ordering},
        },
    };

    use rustls::{ClientConfig, RootCertStore};
    use thiserror::Error;
    use tokio::sync::mpsc;
    use tokio_stream::wrappers::ReceiverStream;

    use crate::{ClientConfigProvider, ClientConfigStreamBuilder, ClientConfigStreamError};

    /// Item type carried by the mock config streams.
    type ConfigResult = Result<Arc<ClientConfig>, ClientConfigStreamError>;

    /// Minimal error used to populate the boxed error variants.
    #[derive(Error, Debug)]
    #[error("{0}")]
    struct MockError(&'static str);

    /// A config with no roots and no client auth; every call returns a distinct `Arc`.
    fn empty_client_config() -> Arc<ClientConfig> {
        Arc::new(
            ClientConfig::builder()
                .with_root_certificates(RootCertStore::empty())
                .with_no_client_auth(),
        )
    }

    /// Builder handing out pre-created channel receivers in order and counting `build`
    /// calls. Once the receivers run out, `build` fails with `StreamBuilderError`.
    #[derive(Debug)]
    struct MockClientConfigStreamBuilder {
        streams: VecDeque<mpsc::Receiver<ConfigResult>>,
        builds: Arc<AtomicUsize>,
    }

    impl MockClientConfigStreamBuilder {
        fn new(streams: Vec<mpsc::Receiver<ConfigResult>>) -> Self {
            Self {
                streams: streams.into(),
                builds: Arc::new(AtomicUsize::new(0)),
            }
        }
    }

    impl ClientConfigStreamBuilder for MockClientConfigStreamBuilder {
        type ConfigStream = ReceiverStream<ConfigResult>;

        async fn build(&mut self) -> Result<Self::ConfigStream, ClientConfigStreamError> {
            self.builds.fetch_add(1, Ordering::SeqCst);
            let rx = self.streams.pop_front().ok_or_else(|| {
                ClientConfigStreamError::StreamBuilderError(MockError("mock stream error").into())
            })?;
            Ok(ReceiverStream::new(rx))
        }
    }

    // Note: `#[tokio::test]` uses a current-thread runtime, so a single `yield_now` lets the
    // spawned refresh task process whatever was just sent before the test continues.

    #[tokio::test]
    async fn start_fails_given_initial_stream_build_failure() {
        let builder = MockClientConfigStreamBuilder::new(vec![]);

        let res = ClientConfigProvider::start(builder).await;
        assert!(
            matches!(res, Err(ClientConfigStreamError::StreamBuilderError(_))),
            "expected ClientConfigStreamError::StreamBuilderError"
        );
    }

    #[tokio::test]
    async fn start_fails_when_stream_is_empty() {
        let (tx, rx) = mpsc::channel(1);
        // Closing the only sender ends the stream before it yields anything.
        drop(tx);
        let builder = MockClientConfigStreamBuilder::new(vec![rx]);

        let res = ClientConfigProvider::start(builder).await;
        assert!(
            matches!(res, Err(ClientConfigStreamError::EmptyStream)),
            "expected ClientConfigStreamError::EmptyStream"
        );
    }

    #[tokio::test]
    async fn start_fails_when_first_result_is_err() {
        let (tx, rx) = mpsc::channel(1);
        let builder = MockClientConfigStreamBuilder::new(vec![rx]);
        tx.send(Err(ClientConfigStreamError::StreamError(
            MockError("fake error").into(),
        )))
        .await
        .unwrap();

        match ClientConfigProvider::start(builder).await {
            Err(ClientConfigStreamError::StreamError(err)) => {
                assert_eq!(err.to_string(), "fake error");
            }
            _ => panic!("expected ClientConfigStreamError::StreamError"),
        }
    }

    #[tokio::test]
    async fn start_and_initial_config_is_loaded() {
        let (tx, rx) = mpsc::channel(1);
        let builder = MockClientConfigStreamBuilder::new(vec![rx]);
        let expected = empty_client_config();
        tx.send(Ok(Arc::clone(&expected))).await.unwrap();

        let provider = ClientConfigProvider::start(builder).await.unwrap();

        assert!(Arc::ptr_eq(&provider.get_config(), &expected));
        assert!(provider.stream_healthy());
    }

    #[tokio::test]
    async fn single_stream_config_hot_swap() {
        let (tx, rx) = mpsc::channel(1);
        let builder = MockClientConfigStreamBuilder::new(vec![rx]);

        let initial = empty_client_config();
        tx.send(Ok(Arc::clone(&initial))).await.unwrap();
        let provider = ClientConfigProvider::start(builder).await.unwrap();
        assert!(Arc::ptr_eq(&provider.get_config(), &initial));
        assert!(provider.stream_healthy());

        for i in 0..10 {
            let expected = empty_client_config();
            tx.send(Ok(Arc::clone(&expected))).await.unwrap();
            tokio::task::yield_now().await;

            assert!(
                Arc::ptr_eq(&provider.get_config(), &expected),
                "config not updated on iter {i}"
            );
            assert!(provider.stream_healthy());
        }
    }

    #[tokio::test]
    async fn stream_failure_triggers_rebuild() {
        let (tx1, rx1) = mpsc::channel(1);
        let (tx2, rx2) = mpsc::channel(1);
        let builder = MockClientConfigStreamBuilder::new(vec![rx1, rx2]);
        let builds = Arc::clone(&builder.builds);
        let initial = empty_client_config();
        tx1.send(Ok(Arc::clone(&initial))).await.unwrap();
        let provider = ClientConfigProvider::start(builder).await.unwrap();
        assert!(Arc::ptr_eq(&provider.get_config(), &initial));
        assert!(provider.stream_healthy());

        tx1.send(Err(ClientConfigStreamError::StreamError(
            MockError("fake error").into(),
        )))
        .await
        .unwrap();
        tokio::task::yield_now().await;
        // The error made the task build the second stream.
        assert_eq!(builds.load(Ordering::SeqCst), 2);

        let new = empty_client_config();
        tx2.send(Ok(Arc::clone(&new))).await.unwrap();
        tokio::task::yield_now().await;

        assert!(provider.stream_healthy());
        assert!(Arc::ptr_eq(&provider.get_config(), &new));
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn stream_rebuild_goes_into_backoff() {
        let (tx, rx) = mpsc::channel(1);
        let builder = MockClientConfigStreamBuilder::new(vec![rx]);
        let builds = Arc::clone(&builder.builds);
        let initial = empty_client_config();
        tx.send(Ok(Arc::clone(&initial))).await.unwrap();
        let provider = ClientConfigProvider::start(builder).await.unwrap();
        assert!(Arc::ptr_eq(&provider.get_config(), &initial));
        assert!(provider.stream_healthy());
        assert_eq!(builds.load(Ordering::SeqCst), 1);

        tx.send(Err(ClientConfigStreamError::StreamError(
            MockError("fake error").into(),
        )))
        .await
        .unwrap();
        tokio::task::yield_now().await;

        // The mock has no streams left, so the rebuild fails and the task backs off
        // while reporting the stream as unhealthy.
        assert_eq!(builds.load(Ordering::SeqCst), 2);
        assert!(!provider.stream_healthy());
    }
}