Skip to main content

scion_sdk_reqwest_connect_rpc/token_source/
refresh.rs

1// Copyright 2025 Anapaya Systems
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//   http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//! [`RefreshTokenSource`] automatically refreshes tokens before expiry using a configurable
15//! [`TokenRefresher`].
16//!
//! Use the builder pattern to configure the refresh threshold, retry delay, and minimum token lifetime.
18//! See [`RefreshTokenSourceBuilder`] for details.
19
20use std::time::{Duration, Instant};
21
22use async_trait::async_trait;
23use tokio::{sync::watch, task::JoinHandle};
24
25use crate::token_source::{TokenSource, TokenSourceError};
26
/// Default pause before retrying after a failed or rejected refresh (5 s).
const DEFAULT_REFRESH_RETRY_DELAY: Duration = Duration::from_millis(5_000);
/// Default lead time before a token's expiry at which a refresh is started (60 s).
const DEFAULT_REFRESH_THRESHOLD: Duration = Duration::from_millis(60_000);
/// Default value of the builder's `refresh_timeout` setting (10 s).
const DEFAULT_REFRESH_TIMEOUT: Duration = Duration::from_millis(10_000);
/// Default minimum remaining lifetime a refreshed token must have to be accepted (10 s).
const DEFAULT_MIN_TOKEN_LIFETIME: Duration = Duration::from_millis(10_000);
31
32/// Builder for a [RefreshTokenSource].
33pub struct RefreshTokenSourceBuilder<T: TokenRefresher> {
34    name: String,
35    token_refresher: T,
36    refresh_retry_delay: Duration,
37    refresh_threshold: Duration,
38    refresh_timeout: Duration,
39    min_token_lifetime: Duration,
40    initial_token: Option<TokenWithExpiry>,
41}
42impl<T: TokenRefresher> RefreshTokenSourceBuilder<T> {
43    /// Creates a new builder for a [RefreshTokenSource].
44    ///
45    /// # Arguments
46    /// * `name` - Name of the token source, used for logging.
47    /// * `token_refresher` - Ability to refresh the token.
48    pub fn new(name: String, token_refresher: T) -> Self {
49        Self {
50            name,
51            token_refresher,
52            refresh_retry_delay: DEFAULT_REFRESH_RETRY_DELAY,
53            refresh_threshold: DEFAULT_REFRESH_THRESHOLD,
54            refresh_timeout: DEFAULT_REFRESH_TIMEOUT,
55            min_token_lifetime: DEFAULT_MIN_TOKEN_LIFETIME,
56            initial_token: None,
57        }
58    }
59
60    /// Seed the token source with an already-fetched token.
61    ///
62    /// When set, the background task publishes this token immediately on startup
63    /// without calling [`TokenRefresher::refresh`] first.  The normal refresh
64    /// loop then takes over before the token expires.
65    pub fn with_initial_token(mut self, token: TokenWithExpiry) -> Self {
66        self.initial_token = Some(token);
67        self
68    }
69
70    /// Minimum lifetime a token must have to be considered valid when returned by `get_token`.
71    pub fn min_token_lifetime(mut self, duration: Duration) -> Self {
72        self.min_token_lifetime = duration;
73        self
74    }
75
76    /// The delay between retries if the refresh function fails.
77    pub fn refresh_retry_delay(mut self, duration: Duration) -> Self {
78        self.refresh_retry_delay = duration;
79        self
80    }
81
82    /// The duration before the token's expiry when a refresh should be attempted.
83    pub fn refresh_threshold(mut self, duration: Duration) -> Self {
84        self.refresh_threshold = duration;
85        self
86    }
87
88    /// The duration to wait for a refresh to complete when `get_token` is called before a timeout.
89    pub fn refresh_timeout(mut self, duration: Duration) -> Self {
90        self.refresh_timeout = duration;
91        self
92    }
93
94    /// Build the [RefreshTokenSource]
95    pub fn build(self) -> RefreshTokenSource {
96        RefreshTokenSource::new(
97            self.name,
98            self.token_refresher,
99            self.refresh_retry_delay,
100            self.refresh_threshold,
101            self.min_token_lifetime,
102            self.initial_token,
103        )
104    }
105}
106
107// ################################
108// RefreshTokenSource
109
110/// A [TokenSource] automatically refreshing the token before it expires.
111pub struct RefreshTokenSource {
112    /// Shared state between the background refresh task and the token source.
113    watch_rx: watch::Receiver<Option<Result<String, TokenSourceError>>>,
114    // Handle to manage the background task, ensuring it is aborted when the `RefreshTokenSource`
115    // is dropped.
116    #[allow(unused)]
117    task_handle: RefreshingTokenSourceTaskHandle,
118}
119
120impl RefreshTokenSource {
121    /// Creates a builder for a `RefreshTokenSource`.
122    pub fn builder<T: TokenRefresher>(
123        name: impl Into<String>,
124        token_refresher: T,
125    ) -> RefreshTokenSourceBuilder<T> {
126        RefreshTokenSourceBuilder::new(name.into(), token_refresher)
127    }
128
129    /// Creates a new `RefreshTokenSource`.
130    ///
131    /// # Arguments
132    /// * `name` - Name of the token source, used for logging.
133    /// * `refresh_function` - Function to refresh the token.
134    /// * `refresh_retry_delay` - Delay between retries if the refresh function fails.
135    /// * `refresh_threshold` - Duration before the token's expiry when a refresh should be
136    ///   attempted.
137    /// * `refresh_timeout` - Duration to wait for a refresh to complete when `get_token` is called.
138    /// * `initial_token` - Optional pre-fetched token to publish immediately on startup.
139    pub fn new(
140        name: String,
141        token_refresher: impl TokenRefresher,
142        refresh_retry_delay: Duration,
143        refresh_threshold: Duration,
144        min_token_lifetime: Duration,
145        initial_token: Option<TokenWithExpiry>,
146    ) -> Self {
147        let (watch_tx, watch_rx) = tokio::sync::watch::channel(None);
148        let inner = RefreshTokenSourceTask {
149            name,
150            watch_tx,
151            refresh_retry_delay,
152            refresh_threshold,
153            min_token_lifetime,
154            token_refresher: Box::new(token_refresher),
155            initial_token,
156        };
157
158        let task_handle = inner.run();
159
160        Self {
161            watch_rx,
162            task_handle,
163        }
164    }
165}
166
167#[async_trait]
168impl TokenSource for RefreshTokenSource {
169    fn watch(&self) -> watch::Receiver<Option<Result<String, TokenSourceError>>> {
170        self.watch_rx.clone()
171    }
172}
173
/// A token paired with the instant at which it expires.
#[derive(Debug, Clone)]
pub struct TokenWithExpiry {
    /// The token itself (a JWT string).
    pub token: String,
    /// Instant at which the token expires.
    pub expires_at: Instant,
}
182
183// ################################
184// RefreshingTokenSourceTaskHandle
185
186/// Handle to manage the background refresh task.
187/// When dropped, the task is aborted.
188struct RefreshingTokenSourceTaskHandle {
189    handle: JoinHandle<()>,
190}
191
192impl Drop for RefreshingTokenSourceTaskHandle {
193    fn drop(&mut self) {
194        self.handle.abort();
195    }
196}
197
198struct RefreshTokenSourceTask {
199    name: String,
200    watch_tx: watch::Sender<Option<Result<String, TokenSourceError>>>,
201    refresh_retry_delay: Duration,
202    refresh_threshold: Duration,
203    #[allow(clippy::type_complexity)]
204    token_refresher: Box<dyn TokenRefresher>,
205    min_token_lifetime: Duration,
206    initial_token: Option<TokenWithExpiry>,
207}
208
209impl RefreshTokenSourceTask {
210    fn run(self) -> RefreshingTokenSourceTaskHandle {
211        let handle = tokio::spawn(async move {
212            let mut fail_count = 0;
213            // If an initial token was provided, publish it immediately and seed
214            // current_token so that the background loop sleeps until near-expiry.
215            let mut current_token: Option<TokenWithExpiry> = if let Some(tok) = self.initial_token {
216                let token_ttl_secs = tok
217                    .expires_at
218                    .saturating_duration_since(Instant::now())
219                    .as_secs();
220                tracing::debug!(
221                    name = %self.name,
222                    token_ttl_secs,
223                    "Published initial token without calling refresh"
224                );
225                self.watch_tx.send_replace(Some(Ok(tok.token.clone())));
226                Some(tok)
227            } else {
228                None
229            };
230            loop {
231                // Determine when to next refresh the token.
232                let token_expiry = match current_token {
233                    Some(ref token) => token.expires_at,
234                    // No token yet, or last refreshes failed, try to get a new token immediately.
235                    _ => Instant::now(),
236                };
237
238                let refresh_deadline = token_expiry
239                    .checked_sub(self.refresh_threshold)
240                    .unwrap_or_else(Instant::now);
241
242                tokio::time::sleep_until(refresh_deadline.into()).await;
243
244                // Attempt to refresh the token.
245                let new_token = self.token_refresher.refresh().await;
246
247                match new_token {
248                    // Got a new token, store it and notify waiters
249                    Ok(token) => {
250                        let token_ttl_secs = token
251                            .expires_at
252                            .saturating_duration_since(Instant::now())
253                            .as_secs();
254
255                        // Validate that token has a decent expiry time
256                        if token.expires_at <= Instant::now() + self.min_token_lifetime {
257                            tracing::error!(
258                                name = %self.name,
259                                token_ttl_secs,
260                                "Refreshed token is already expired or too close to expiry, ignoring"
261                            );
262                            // XXX(ake): Not sure if we should abort here instead?
263
264                            // Wait before trying again to avoid busy looping
265                            tokio::time::sleep(self.refresh_retry_delay).await;
266                            continue;
267                        }
268
269                        fail_count = 0;
270
271                        tracing::info!(
272                            name = %self.name,
273                            token_ttl_secs,
274                            "Refreshed token"
275                        );
276
277                        current_token = Some(token.clone());
278                        self.watch_tx.send_replace(Some(Ok(token.token)));
279                    }
280                    // Failed to refresh the token, log the error and retry after a delay
281                    Err(e) => {
282                        fail_count += 1;
283
284                        tracing::error!(
285                            name = %self.name,
286                            ttl_secs = token_expiry.saturating_duration_since(Instant::now()).as_secs(),
287                            retry_secs = self.refresh_retry_delay.as_secs(),
288                            fail_count,
289                            error = %e,
290                            "Failed to refresh token"
291                        );
292
293                        // If the current token is still valid, keep it, otherwise store the error
294                        // and notify waiters
295                        if token_expiry <= Instant::now() + self.min_token_lifetime {
296                            current_token = None;
297                            self.watch_tx.send_replace(Some(Err(e)));
298                            continue;
299                        }
300
301                        tokio::time::sleep(self.refresh_retry_delay).await;
302                        continue;
303                    }
304                }
305            }
306        });
307
308        RefreshingTokenSourceTaskHandle { handle }
309    }
310}
311
312// ################################
313// TokenRefresher
314
315/// Anything which allows to refresh a token.
316///
317/// Default implementations are provided for async functions and closures
318#[async_trait]
319pub trait TokenRefresher: Send + Sync + 'static {
320    /// Refreshes the token and return the new token and its expiry time.
321    async fn refresh(&self) -> Result<TokenWithExpiry, TokenSourceError>;
322}
323
324/// Allow any async function or closure matching the signature to be used as a TokenRefresher.
325#[async_trait]
326impl<AsyncFn, FnFuture> TokenRefresher for AsyncFn
327where
328    AsyncFn: Fn() -> FnFuture + Send + Sync + 'static,
329    FnFuture: Future<Output = Result<TokenWithExpiry, TokenSourceError>> + Send,
330{
331    async fn refresh(&self) -> Result<TokenWithExpiry, TokenSourceError> {
332        (self)().await
333    }
334}
335
#[cfg(test)]
mod tests {
    use std::{
        sync::{
            Arc,
            atomic::{AtomicUsize, Ordering},
        },
        time::{Duration, Instant},
    };

    use tokio::sync::Notify;

    use super::*;

    /// Returns a token with the given value that expires `ttl` from now.
    fn token_expiring_in(token: &str, ttl: Duration) -> TokenWithExpiry {
        TokenWithExpiry {
            token: token.to_string(),
            expires_at: Instant::now() + ttl,
        }
    }

    #[tokio::test]
    async fn initial_token_is_published_without_calling_refresh() {
        let refresh_count = Arc::new(AtomicUsize::new(0));
        let refresh_count_clone = Arc::clone(&refresh_count);

        let source = RefreshTokenSource::builder("test", move || {
            refresh_count_clone.fetch_add(1, Ordering::SeqCst);
            let token = token_expiring_in("refreshed-token", Duration::from_secs(3600));
            async move { Ok::<_, TokenSourceError>(token) }
        })
        .with_initial_token(token_expiring_in("initial-token", Duration::from_secs(3600)))
        // With a 1 h expiry and a 60 s threshold, the first refresh is scheduled ~59 min out,
        // far beyond the lifetime of this test.
        .refresh_threshold(Duration::from_secs(60))
        .build();

        // Yield once so the background task runs up to its first `sleep_until`.
        tokio::task::yield_now().await;

        // The watch channel should already hold the initial token.
        let mut rx = source.watch();
        let borrow = rx.borrow_and_update();
        match borrow.as_ref() {
            Some(Ok(token)) => assert_eq!(token, "initial-token"),
            other => panic!("expected initial token, got {other:?}"),
        }
        drop(borrow);

        assert_eq!(
            refresh_count.load(Ordering::SeqCst),
            0,
            "refresh() should not be called when an initial token is provided"
        );
    }

    #[tokio::test]
    async fn initial_token_expiry_triggers_refresh() {
        let notify = Arc::new(Notify::new());
        let notify_clone = Arc::clone(&notify);

        let source = RefreshTokenSource::builder("test", move || {
            notify_clone.notify_one();
            let token = token_expiring_in("refreshed-token", Duration::from_secs(3600));
            async move { Ok::<_, TokenSourceError>(token) }
        })
        // Expire very soon so the background task wakes up quickly.
        .with_initial_token(token_expiring_in("initial-token", Duration::from_millis(10)))
        // Zero threshold: sleep until exactly the expiry instant (10 ms from now).
        .refresh_threshold(Duration::ZERO)
        .build();

        // Yield once so the background task starts and sleeps until the 10 ms deadline.
        tokio::task::yield_now().await;

        // Wait for refresh() to be called. 500 ms gives a 50x margin over the 10 ms sleep.
        tokio::time::timeout(Duration::from_millis(500), notify.notified())
            .await
            .expect("refresh() should be called after the initial token expires");

        // The initial token should have been replaced by the refreshed token.
        let token = tokio::time::timeout(Duration::from_millis(500), source.get_token())
            .await
            .expect("get_token() should not timeout")
            .expect("get_token() should succeed");
        assert_eq!(token, "refreshed-token");
    }
}