Skip to main content

scion_sdk_reqwest_connect_rpc/token_source/
refresh.rs

1// Copyright 2025 Anapaya Systems
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//   http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//! [`RefreshTokenSource`] automatically refreshes tokens before expiry using a configurable
15//! [`TokenRefresher`].
16//!
17//! Use the builder pattern to configure refresh intervals, timeouts, and minimum token lifetimes.
18//! See [`RefreshTokenSourceBuilder`] for details.
19
20use std::time::{Duration, Instant};
21
22use async_trait::async_trait;
23use tokio::{sync::watch, task::JoinHandle};
24
25use crate::token_source::{TokenSource, TokenSourceError};
26
/// Default back-off between refresh retries.
const DEFAULT_REFRESH_RETRY_DELAY: Duration = Duration::from_secs(5);
/// Default lead time before a token's expiry at which a refresh is started.
const DEFAULT_REFRESH_THRESHOLD: Duration = Duration::from_secs(60);
/// Default value of the builder's `refresh_timeout` setting.
const DEFAULT_REFRESH_TIMEOUT: Duration = Duration::from_secs(10);
/// Default minimum remaining lifetime a token must have to be used.
const DEFAULT_MIN_TOKEN_LIFETIME: Duration = Duration::from_secs(10);
31
32/// Builder for a [RefreshTokenSource].
33pub struct RefreshTokenSourceBuilder<T: TokenRefresher> {
34    name: String,
35    token_refresher: T,
36    refresh_retry_delay: Duration,
37    refresh_threshold: Duration,
38    refresh_timeout: Duration,
39    min_token_lifetime: Duration,
40}
41impl<T: TokenRefresher> RefreshTokenSourceBuilder<T> {
42    /// Creates a new builder for a [RefreshTokenSource].
43    ///
44    /// # Arguments
45    /// * `name` - Name of the token source, used for logging.
46    /// * `token_refresher` - Ability to refresh the token.
47    pub fn new(name: String, token_refresher: T) -> Self {
48        Self {
49            name,
50            token_refresher,
51            refresh_retry_delay: DEFAULT_REFRESH_RETRY_DELAY,
52            refresh_threshold: DEFAULT_REFRESH_THRESHOLD,
53            refresh_timeout: DEFAULT_REFRESH_TIMEOUT,
54            min_token_lifetime: DEFAULT_MIN_TOKEN_LIFETIME,
55        }
56    }
57
58    /// Minimum lifetime a token must have to be considered valid when returned by `get_token`.
59    pub fn min_token_lifetime(mut self, duration: Duration) -> Self {
60        self.min_token_lifetime = duration;
61        self
62    }
63
64    /// The delay between retries if the refresh function fails.
65    pub fn refresh_retry_delay(mut self, duration: Duration) -> Self {
66        self.refresh_retry_delay = duration;
67        self
68    }
69
70    /// The duration before the token's expiry when a refresh should be attempted.
71    pub fn refresh_threshold(mut self, duration: Duration) -> Self {
72        self.refresh_threshold = duration;
73        self
74    }
75
76    /// The duration to wait for a refresh to complete when `get_token` is called before a timeout.
77    pub fn refresh_timeout(mut self, duration: Duration) -> Self {
78        self.refresh_timeout = duration;
79        self
80    }
81
82    /// Build the [RefreshTokenSource]
83    pub fn build(self) -> RefreshTokenSource {
84        RefreshTokenSource::new(
85            self.name,
86            self.token_refresher,
87            self.refresh_retry_delay,
88            self.refresh_threshold,
89            self.refresh_timeout,
90        )
91    }
92}
93
94// ################################
95// RefreshTokenSource
96
97/// A [TokenSource] automatically refreshing the token before it expires.
98pub struct RefreshTokenSource {
99    /// Shared state between the background refresh task and the token source.
100    watch_rx: watch::Receiver<Option<Result<String, TokenSourceError>>>,
101    // Handle to manage the background task, ensuring it is aborted when the `RefreshTokenSource`
102    // is dropped.
103    #[allow(unused)]
104    task_handle: RefreshingTokenSourceTaskHandle,
105}
106
107impl RefreshTokenSource {
108    /// Creates a builder for a `RefreshTokenSource`.
109    pub fn builder<T: TokenRefresher>(
110        name: impl Into<String>,
111        token_refresher: T,
112    ) -> RefreshTokenSourceBuilder<T> {
113        RefreshTokenSourceBuilder::new(name.into(), token_refresher)
114    }
115
116    /// Creates a new `RefreshTokenSource`.
117    ///
118    /// # Arguments
119    /// * `name` - Name of the token source, used for logging.
120    /// * `refresh_function` - Function to refresh the token.
121    /// * `refresh_retry_delay` - Delay between retries if the refresh function fails.
122    /// * `refresh_threshold` - Duration before the token's expiry when a refresh should be
123    ///   attempted.
124    /// * `refresh_timeout` - Duration to wait for a refresh to complete when `get_token` is called.
125    pub fn new(
126        name: String,
127        token_refresher: impl TokenRefresher,
128        refresh_retry_delay: Duration,
129        refresh_threshold: Duration,
130        min_token_lifetime: Duration,
131    ) -> Self {
132        let (watch_tx, watch_rx) = tokio::sync::watch::channel(None);
133        let inner = RefreshTokenSourceTask {
134            name,
135            watch_tx,
136            refresh_retry_delay,
137            refresh_threshold,
138            min_token_lifetime,
139            token_refresher: Box::new(token_refresher),
140        };
141
142        let task_handle = inner.run();
143
144        Self {
145            watch_rx,
146            task_handle,
147        }
148    }
149}
150
151#[async_trait]
152impl TokenSource for RefreshTokenSource {
153    fn watch(&self) -> watch::Receiver<Option<Result<String, TokenSourceError>>> {
154        self.watch_rx.clone()
155    }
156}
157
/// A token paired with the instant at which it expires.
#[derive(Debug, Clone)]
pub struct TokenWithExpiry {
    /// The token value (a JWT).
    pub token: String,
    /// Instant at which the token expires.
    pub expires_at: Instant,
}
166
167// ################################
168// RefreshingTokenSourceTaskHandle
169
170/// Handle to manage the background refresh task.
171/// When dropped, the task is aborted.
172struct RefreshingTokenSourceTaskHandle {
173    handle: JoinHandle<()>,
174}
175
176impl Drop for RefreshingTokenSourceTaskHandle {
177    fn drop(&mut self) {
178        self.handle.abort();
179    }
180}
181
182struct RefreshTokenSourceTask {
183    name: String,
184    watch_tx: watch::Sender<Option<Result<String, TokenSourceError>>>,
185    refresh_retry_delay: Duration,
186    refresh_threshold: Duration,
187    #[allow(clippy::type_complexity)]
188    token_refresher: Box<dyn TokenRefresher>,
189    min_token_lifetime: Duration,
190}
191
192impl RefreshTokenSourceTask {
193    fn run(self) -> RefreshingTokenSourceTaskHandle {
194        let handle = tokio::spawn(async move {
195            let mut fail_count = 0;
196            let mut current_token: Option<TokenWithExpiry> = None;
197            loop {
198                // Determine when to next refresh the token.
199                let token_expiry = match current_token {
200                    Some(ref token) => token.expires_at,
201                    // No token yet, or last refreshes failed, try to get a new token immediately.
202                    _ => Instant::now(),
203                };
204
205                let refresh_deadline = token_expiry
206                    .checked_sub(self.refresh_threshold)
207                    .unwrap_or_else(Instant::now);
208
209                tokio::time::sleep_until(refresh_deadline.into()).await;
210
211                // Attempt to refresh the token.
212                let new_token = self.token_refresher.refresh().await;
213
214                match new_token {
215                    // Got a new token, store it and notify waiters
216                    Ok(token) => {
217                        let token_ttl_secs = token
218                            .expires_at
219                            .saturating_duration_since(Instant::now())
220                            .as_secs();
221
222                        // Validate that token has a decent expiry time
223                        if token.expires_at <= Instant::now() + self.min_token_lifetime {
224                            tracing::error!(
225                                name = %self.name,
226                                token_ttl_secs,
227                                "Refreshed token is already expired or too close to expiry, ignoring"
228                            );
229                            // XXX(ake): Not sure if we should abort here instead?
230
231                            // Wait before trying again to avoid busy looping
232                            tokio::time::sleep(self.refresh_retry_delay).await;
233                            continue;
234                        }
235
236                        fail_count = 0;
237
238                        tracing::info!(
239                            name = %self.name,
240                            token_ttl_secs,
241                            "Refreshed token"
242                        );
243
244                        current_token = Some(token.clone());
245                        self.watch_tx.send_replace(Some(Ok(token.token)));
246                    }
247                    // Failed to refresh the token, log the error and retry after a delay
248                    Err(e) => {
249                        fail_count += 1;
250
251                        tracing::error!(
252                            name = %self.name,
253                            ttl_secs = token_expiry.saturating_duration_since(Instant::now()).as_secs(),
254                            retry_secs = self.refresh_retry_delay.as_secs(),
255                            fail_count,
256                            error = %e,
257                            "Failed to refresh token"
258                        );
259
260                        // If the current token is still valid, keep it, otherwise store the error
261                        // and notify waiters
262                        if token_expiry <= Instant::now() + self.min_token_lifetime {
263                            current_token = None;
264                            self.watch_tx.send_replace(Some(Err(e)));
265                            continue;
266                        }
267
268                        tokio::time::sleep(self.refresh_retry_delay).await;
269                        continue;
270                    }
271                }
272            }
273        });
274
275        RefreshingTokenSourceTaskHandle { handle }
276    }
277}
278
279// ################################
280// TokenRefresher
281
282/// Anything which allows to refresh a token.
283///
284/// Default implementations are provided for async functions and closures
285#[async_trait]
286pub trait TokenRefresher: Send + Sync + 'static {
287    /// Refreshes the token and return the new token and its expiry time.
288    async fn refresh(&self) -> Result<TokenWithExpiry, TokenSourceError>;
289}
290
291/// Allow any async function or closure matching the signature to be used as a TokenRefresher.
292#[async_trait]
293impl<AsyncFn, FnFuture> TokenRefresher for AsyncFn
294where
295    AsyncFn: Fn() -> FnFuture + Send + Sync + 'static,
296    FnFuture: Future<Output = Result<TokenWithExpiry, TokenSourceError>> + Send,
297{
298    async fn refresh(&self) -> Result<TokenWithExpiry, TokenSourceError> {
299        (self)().await
300    }
301}