scion_sdk_reqwest_connect_rpc/token_source/
refresh.rs

1// Copyright 2025 Anapaya Systems
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//   http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//! [`RefreshTokenSource`] automatically refreshes tokens before expiry using a configurable
15//! [`TokenRefresher`].
16//!
17//! Use the builder pattern to configure refresh intervals, timeouts, and minimum token lifetimes.
18//! See [`RefreshTokenSourceBuilder`] for details.
19
20use std::{
21    sync::Arc,
22    time::{Duration, Instant},
23};
24
25use async_trait::async_trait;
26use tokio::{
27    sync::{Notify, RwLock},
28    task::JoinHandle,
29    time::timeout,
30};
31
32use crate::token_source::{TokenSource, TokenSourceError};
33
34const DEFAULT_REFRESH_RETRY_DELAY: Duration = Duration::from_secs(5);
35const DEFAULT_REFRESH_THRESHOLD: Duration = Duration::from_secs(60);
36const DEFAULT_REFRESH_TIMEOUT: Duration = Duration::from_secs(10);
37const DEFAULT_MIN_TOKEN_LIFETIME: Duration = Duration::from_secs(10);
38
/// Builder for a [RefreshTokenSource].
///
/// Holds the token refresher plus the timing parameters that are handed to
/// [RefreshTokenSource::new] when `build` is called. All durations start at the module
/// defaults and can be overridden through the builder's setter methods.
pub struct RefreshTokenSourceBuilder<T: TokenRefresher> {
    /// Name of the token source, used for logging.
    name: String,
    /// Produces new tokens for the background refresh task.
    token_refresher: T,
    /// Delay between retries if a refresh fails (default: 5 seconds).
    refresh_retry_delay: Duration,
    /// How long before the token's expiry a refresh is attempted (default: 60 seconds).
    refresh_threshold: Duration,
    /// How long `get_token` waits for a refresh to complete (default: 10 seconds).
    refresh_timeout: Duration,
    /// Minimum remaining lifetime a token must have to be handed out (default: 10 seconds).
    min_token_lifetime: Duration,
}
48impl<T: TokenRefresher> RefreshTokenSourceBuilder<T> {
49    /// Creates a new builder for a [RefreshTokenSource].
50    ///
51    /// # Arguments
52    /// * `name` - Name of the token source, used for logging.
53    /// * `token_refresher` - Ability to refresh the token.
54    pub fn new(name: String, token_refresher: T) -> Self {
55        Self {
56            name,
57            token_refresher,
58            refresh_retry_delay: DEFAULT_REFRESH_RETRY_DELAY,
59            refresh_threshold: DEFAULT_REFRESH_THRESHOLD,
60            refresh_timeout: DEFAULT_REFRESH_TIMEOUT,
61            min_token_lifetime: DEFAULT_MIN_TOKEN_LIFETIME,
62        }
63    }
64
65    /// Minimum lifetime a token must have to be considered valid when returned by `get_token`.
66    pub fn min_token_lifetime(mut self, duration: Duration) -> Self {
67        self.min_token_lifetime = duration;
68        self
69    }
70
71    /// The delay between retries if the refresh function fails.
72    pub fn refresh_retry_delay(mut self, duration: Duration) -> Self {
73        self.refresh_retry_delay = duration;
74        self
75    }
76
77    /// The duration before the token's expiry when a refresh should be attempted.
78    pub fn refresh_threshold(mut self, duration: Duration) -> Self {
79        self.refresh_threshold = duration;
80        self
81    }
82
83    /// The duration to wait for a refresh to complete when `get_token` is called before a timeout.
84    pub fn refresh_timeout(mut self, duration: Duration) -> Self {
85        self.refresh_timeout = duration;
86        self
87    }
88
89    /// Build the [RefreshTokenSource]
90    pub fn build(self) -> RefreshTokenSource {
91        RefreshTokenSource::new(
92            self.name,
93            self.token_refresher,
94            self.refresh_retry_delay,
95            self.refresh_threshold,
96            self.refresh_timeout,
97            self.min_token_lifetime,
98        )
99    }
100}
101
102// ################################
103// RefreshTokenSource
104
/// A [TokenSource] automatically refreshing the token before it expires.
///
/// A background task performs the refreshes and publishes each outcome into `result`;
/// `get_token` only reads that shared state and waits on `refresh_notify` when no usable
/// token is available yet.
pub struct RefreshTokenSource {
    /// Shared state between the background refresh task and the token source.
    ///
    /// `None` until the first refresh outcome has been stored; afterwards either the latest
    /// token or the error of a failed refresh.
    result: Arc<RwLock<Option<Result<TokenWithExpiry, TokenSourceError>>>>,
    /// Notifies waiters when a refresh has completed.
    refresh_notify: Arc<Notify>,
    /// The duration to wait for a refresh to complete when `get_token` is called before a timeout
    /// error is returned.
    refresh_timeout: Duration,
    /// Minimum lifetime a token must have to be considered valid when returned by `get_token`.
    min_token_lifetime: Duration,
    // Handle to manage the background task, ensuring it is aborted when the `RefreshTokenSource`
    // is dropped. Never read directly; it is kept only for its `Drop` side effect.
    #[allow(unused)]
    task_handle: RefreshingTokenSourceTaskHandle,
}
121
122impl RefreshTokenSource {
123    /// Creates a builder for a `RefreshTokenSource`.
124    pub fn builder<T: TokenRefresher>(
125        name: impl Into<String>,
126        token_refresher: T,
127    ) -> RefreshTokenSourceBuilder<T> {
128        RefreshTokenSourceBuilder::new(name.into(), token_refresher)
129    }
130
131    /// Creates a new `RefreshTokenSource`.
132    ///
133    /// # Arguments
134    /// * `name` - Name of the token source, used for logging.
135    /// * `refresh_function` - Function to refresh the token.
136    /// * `refresh_retry_delay` - Delay between retries if the refresh function fails.
137    /// * `refresh_threshold` - Duration before the token's expiry when a refresh should be
138    ///   attempted.
139    /// * `refresh_timeout` - Duration to wait for a refresh to complete when `get_token` is called.
140    pub fn new(
141        name: String,
142        token_refresher: impl TokenRefresher,
143        refresh_retry_delay: Duration,
144        refresh_threshold: Duration,
145        refresh_timeout: Duration,
146        min_token_lifetime: Duration,
147    ) -> Self {
148        let refresh_notify = Arc::new(Notify::new());
149        let result = Arc::new(RwLock::new(None));
150        let inner = RefreshTokenSourceTask {
151            name,
152            result: result.clone(),
153            refresh_notify: refresh_notify.clone(),
154            refresh_retry_delay,
155            refresh_threshold,
156            min_token_lifetime,
157            token_refresher: Box::new(token_refresher),
158        };
159
160        let task_handle = inner.run();
161
162        Self {
163            result: result.clone(),
164            refresh_notify,
165            refresh_timeout,
166            min_token_lifetime,
167            task_handle,
168        }
169    }
170}
171
172#[async_trait]
173impl TokenSource for RefreshTokenSource {
174    async fn get_token(&self) -> Result<String, TokenSourceError> {
175        loop {
176            let guard = self.result.read().await;
177
178            match guard.as_ref() {
179                // Return the token if it is still valid for at least `min_token_lifetime`
180                Some(Ok(token))
181                    if token.expires_at > (Instant::now() + self.min_token_lifetime) =>
182                {
183                    return Ok(token.token.clone());
184                }
185                // If we have an error, return it
186                Some(Err(e)) => {
187                    // Stringify the error to avoid lifetime issues
188                    return Err(e.to_string().into());
189                }
190                // If we have a expired token or no result wait for a refresh and try again
191                Some(Ok(_)) | None => {
192                    let notify = self.refresh_notify.clone();
193                    let notified = notify.notified();
194
195                    // Must drop after getting a notified to avoid missing a notification
196                    drop(guard);
197
198                    timeout(self.refresh_timeout, notified)
199                        .await
200                        .map_err(|_| "timed out waiting for token refresh".to_string())?;
201
202                    continue;
203                }
204            }
205        }
206    }
207}
208
/// A token with its expiry time.
///
/// This is the value a [TokenRefresher] returns on a successful refresh.
pub struct TokenWithExpiry {
    /// JWT string.
    pub token: String,
    /// Token expiry, as a monotonic [Instant] that is compared against `Instant::now()`.
    pub expires_at: Instant,
}
216
217// ################################
218// RefreshingTokenSourceTaskHandle
219
/// Handle to manage the background refresh task.
/// When dropped, the task is aborted.
struct RefreshingTokenSourceTaskHandle {
    /// Join handle of the spawned refresh loop.
    handle: JoinHandle<()>,
}
225
226impl Drop for RefreshingTokenSourceTaskHandle {
227    fn drop(&mut self) {
228        self.handle.abort();
229    }
230}
231
/// State owned by the background refresh loop started by `run`.
struct RefreshTokenSourceTask {
    /// Name of the token source, attached to every log event.
    name: String,
    /// Shared slot the latest refresh outcome (token or error) is written to.
    result: Arc<RwLock<Option<Result<TokenWithExpiry, TokenSourceError>>>>,
    /// Used to wake callers waiting for a refresh outcome.
    refresh_notify: Arc<Notify>,
    /// Delay before the next attempt after a failed or rejected refresh.
    refresh_retry_delay: Duration,
    /// How long before the current token's expiry the next refresh is started.
    refresh_threshold: Duration,
    /// Produces new tokens; boxed so the task itself is not generic.
    #[allow(clippy::type_complexity)]
    token_refresher: Box<dyn TokenRefresher>,
    /// Minimum remaining lifetime a token must have to be stored or kept.
    min_token_lifetime: Duration,
}
242
243impl RefreshTokenSourceTask {
244    fn run(self) -> RefreshingTokenSourceTaskHandle {
245        let handle = tokio::spawn(async move {
246            let mut fail_count = 0;
247            loop {
248                // Determine when to next refresh the token.
249                let token_expiry = match self.result.read().await.as_ref() {
250                    Some(Ok(token)) => token.expires_at,
251                    // No token yet, or last refreshes failed, try to get a new token immediately.
252                    _ => Instant::now(),
253                };
254
255                let refresh_deadline = token_expiry
256                    .checked_sub(self.refresh_threshold)
257                    .unwrap_or_else(Instant::now);
258
259                tokio::time::sleep_until(refresh_deadline.into()).await;
260
261                // Attempt to refresh the token.
262                let new_token = self.token_refresher.refresh().await;
263
264                match new_token {
265                    // Got a new token, store it and notify waiters
266                    Ok(token) => {
267                        let token_ttl_secs = token
268                            .expires_at
269                            .saturating_duration_since(Instant::now())
270                            .as_secs();
271
272                        // Validate that token has a decent expiry time
273                        if token.expires_at <= Instant::now() + self.min_token_lifetime {
274                            tracing::error!(
275                                name = %self.name,
276                                token_ttl_secs,
277                                "Refreshed token is already expired or too close to expiry, ignoring"
278                            );
279                            // XXX(ake): Not sure if we should abort here instead?
280
281                            // Wait before trying again to avoid busy looping
282                            tokio::time::sleep(self.refresh_retry_delay).await;
283                            continue;
284                        }
285
286                        fail_count = 0;
287
288                        tracing::info!(
289                            name = %self.name,
290                            token_ttl_secs,
291                            "Refreshed token"
292                        );
293
294                        // Store the new token and notify waiters
295                        {
296                            let mut write_guard = self.result.write().await;
297                            *write_guard = Some(Ok(token));
298                            self.refresh_notify.notify_waiters(); // Must be inside the write lock to avoid missed notifications
299                        }
300                    }
301                    // Failed to refresh the token, log the error and retry after a delay
302                    Err(e) => {
303                        fail_count += 1;
304
305                        tracing::error!(
306                            name = %self.name,
307                            ttl_secs = token_expiry.saturating_duration_since(Instant::now()).as_secs(),
308                            retry_secs = self.refresh_retry_delay.as_secs(),
309                            fail_count,
310                            error = %e,
311                            "Failed to refresh token"
312                        );
313
314                        // If the current token is still valid, keep it, otherwise store the error
315                        // and notify waiters
316                        if token_expiry <= Instant::now() + self.min_token_lifetime {
317                            {
318                                let mut write_guard = self.result.write().await;
319                                *write_guard = Some(Err(e));
320                                self.refresh_notify.notify_waiters(); // Must be inside the write lock to avoid missed notifications
321                            }
322                            continue;
323                        }
324
325                        tokio::time::sleep(self.refresh_retry_delay).await;
326                        continue;
327                    }
328                }
329            }
330        });
331
332        RefreshingTokenSourceTaskHandle { handle }
333    }
334}
335
336// ################################
337// TokenRefresher
338
/// Anything which allows to refresh a token.
///
/// A blanket implementation is provided for async functions and closures that take no
/// arguments and return `Result<TokenWithExpiry, TokenSourceError>`.
///
/// Implementors must be `Send + Sync + 'static`.
#[async_trait]
pub trait TokenRefresher: Send + Sync + 'static {
    /// Refreshes the token and returns the new token together with its expiry time.
    ///
    /// # Errors
    /// Returns a [TokenSourceError] when no new token could be obtained.
    async fn refresh(&self) -> Result<TokenWithExpiry, TokenSourceError>;
}
347
348/// Allow any async function or closure matching the signature to be used as a TokenRefresher.
349#[async_trait]
350impl<AsyncFn, FnFuture> TokenRefresher for AsyncFn
351where
352    AsyncFn: Fn() -> FnFuture + Send + Sync + 'static,
353    FnFuture: Future<Output = Result<TokenWithExpiry, TokenSourceError>> + Send,
354{
355    async fn refresh(&self) -> Result<TokenWithExpiry, TokenSourceError> {
356        (self)().await
357    }
358}