scion_sdk_reqwest_connect_rpc/token_source/refresh.rs
1// Copyright 2025 Anapaya Systems
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//! [`RefreshTokenSource`] automatically refreshes tokens before expiry using a configurable
15//! [`TokenRefresher`].
16//!
17//! Use the builder pattern to configure refresh intervals, timeouts, and minimum token lifetimes.
18//! See [`RefreshTokenSourceBuilder`] for details.
19
20use std::{
21 sync::Arc,
22 time::{Duration, Instant},
23};
24
25use async_trait::async_trait;
26use tokio::{
27 sync::{Notify, RwLock},
28 task::JoinHandle,
29 time::timeout,
30};
31
32use crate::token_source::{TokenSource, TokenSourceError};
33
/// Default delay before the background task tries again after an unsuccessful refresh.
const DEFAULT_REFRESH_RETRY_DELAY: Duration = Duration::from_secs(5);
/// Default margin before a token's expiry at which the next refresh is started.
const DEFAULT_REFRESH_THRESHOLD: Duration = Duration::from_secs(60);
/// Default upper bound on how long a single `get_token` wait for a refresh may take.
const DEFAULT_REFRESH_TIMEOUT: Duration = Duration::from_secs(10);
/// Default minimum remaining lifetime a token needs in order to be handed out.
const DEFAULT_MIN_TOKEN_LIFETIME: Duration = Duration::from_secs(10);
38
39/// Builder for a [RefreshTokenSource].
40pub struct RefreshTokenSourceBuilder<T: TokenRefresher> {
41 name: String,
42 token_refresher: T,
43 refresh_retry_delay: Duration,
44 refresh_threshold: Duration,
45 refresh_timeout: Duration,
46 min_token_lifetime: Duration,
47}
48impl<T: TokenRefresher> RefreshTokenSourceBuilder<T> {
49 /// Creates a new builder for a [RefreshTokenSource].
50 ///
51 /// # Arguments
52 /// * `name` - Name of the token source, used for logging.
53 /// * `token_refresher` - Ability to refresh the token.
54 pub fn new(name: String, token_refresher: T) -> Self {
55 Self {
56 name,
57 token_refresher,
58 refresh_retry_delay: DEFAULT_REFRESH_RETRY_DELAY,
59 refresh_threshold: DEFAULT_REFRESH_THRESHOLD,
60 refresh_timeout: DEFAULT_REFRESH_TIMEOUT,
61 min_token_lifetime: DEFAULT_MIN_TOKEN_LIFETIME,
62 }
63 }
64
65 /// Minimum lifetime a token must have to be considered valid when returned by `get_token`.
66 pub fn min_token_lifetime(mut self, duration: Duration) -> Self {
67 self.min_token_lifetime = duration;
68 self
69 }
70
71 /// The delay between retries if the refresh function fails.
72 pub fn refresh_retry_delay(mut self, duration: Duration) -> Self {
73 self.refresh_retry_delay = duration;
74 self
75 }
76
77 /// The duration before the token's expiry when a refresh should be attempted.
78 pub fn refresh_threshold(mut self, duration: Duration) -> Self {
79 self.refresh_threshold = duration;
80 self
81 }
82
83 /// The duration to wait for a refresh to complete when `get_token` is called before a timeout.
84 pub fn refresh_timeout(mut self, duration: Duration) -> Self {
85 self.refresh_timeout = duration;
86 self
87 }
88
89 /// Build the [RefreshTokenSource]
90 pub fn build(self) -> RefreshTokenSource {
91 RefreshTokenSource::new(
92 self.name,
93 self.token_refresher,
94 self.refresh_retry_delay,
95 self.refresh_threshold,
96 self.refresh_timeout,
97 self.min_token_lifetime,
98 )
99 }
100}
101
102// ################################
103// RefreshTokenSource
104
105/// A [TokenSource] automatically refreshing the token before it expires.
106pub struct RefreshTokenSource {
107 /// Shared state between the background refresh task and the token source.
108 result: Arc<RwLock<Option<Result<TokenWithExpiry, TokenSourceError>>>>,
109 /// Notifies waiters when a refresh has completed.
110 refresh_notify: Arc<Notify>,
111 /// The duration to wait for a refresh to complete when `get_token` is called before a timeout
112 /// error is returned.
113 refresh_timeout: Duration,
114 /// Minimum lifetime a token must have to be considered valid when returned by `get_token`.
115 min_token_lifetime: Duration,
116 // Handle to manage the background task, ensuring it is aborted when the `RefreshTokenSource`
117 // is dropped.
118 #[allow(unused)]
119 task_handle: RefreshingTokenSourceTaskHandle,
120}
121
122impl RefreshTokenSource {
123 /// Creates a builder for a `RefreshTokenSource`.
124 pub fn builder<T: TokenRefresher>(
125 name: impl Into<String>,
126 token_refresher: T,
127 ) -> RefreshTokenSourceBuilder<T> {
128 RefreshTokenSourceBuilder::new(name.into(), token_refresher)
129 }
130
131 /// Creates a new `RefreshTokenSource`.
132 ///
133 /// # Arguments
134 /// * `name` - Name of the token source, used for logging.
135 /// * `refresh_function` - Function to refresh the token.
136 /// * `refresh_retry_delay` - Delay between retries if the refresh function fails.
137 /// * `refresh_threshold` - Duration before the token's expiry when a refresh should be
138 /// attempted.
139 /// * `refresh_timeout` - Duration to wait for a refresh to complete when `get_token` is called.
140 pub fn new(
141 name: String,
142 token_refresher: impl TokenRefresher,
143 refresh_retry_delay: Duration,
144 refresh_threshold: Duration,
145 refresh_timeout: Duration,
146 min_token_lifetime: Duration,
147 ) -> Self {
148 let refresh_notify = Arc::new(Notify::new());
149 let result = Arc::new(RwLock::new(None));
150 let inner = RefreshTokenSourceTask {
151 name,
152 result: result.clone(),
153 refresh_notify: refresh_notify.clone(),
154 refresh_retry_delay,
155 refresh_threshold,
156 min_token_lifetime,
157 token_refresher: Box::new(token_refresher),
158 };
159
160 let task_handle = inner.run();
161
162 Self {
163 result: result.clone(),
164 refresh_notify,
165 refresh_timeout,
166 min_token_lifetime,
167 task_handle,
168 }
169 }
170}
171
172#[async_trait]
173impl TokenSource for RefreshTokenSource {
174 async fn get_token(&self) -> Result<String, TokenSourceError> {
175 loop {
176 let guard = self.result.read().await;
177
178 match guard.as_ref() {
179 // Return the token if it is still valid for at least `min_token_lifetime`
180 Some(Ok(token))
181 if token.expires_at > (Instant::now() + self.min_token_lifetime) =>
182 {
183 return Ok(token.token.clone());
184 }
185 // If we have an error, return it
186 Some(Err(e)) => {
187 // Stringify the error to avoid lifetime issues
188 return Err(e.to_string().into());
189 }
190 // If we have a expired token or no result wait for a refresh and try again
191 Some(Ok(_)) | None => {
192 let notify = self.refresh_notify.clone();
193 let notified = notify.notified();
194
195 // Must drop after getting a notified to avoid missing a notification
196 drop(guard);
197
198 timeout(self.refresh_timeout, notified)
199 .await
200 .map_err(|_| "timed out waiting for token refresh".to_string())?;
201
202 continue;
203 }
204 }
205 }
206 }
207}
208
/// A token together with the instant at which it expires.
///
/// This is what a [`TokenRefresher`] hands back on a successful refresh.
pub struct TokenWithExpiry {
    /// The token string (JWT).
    pub token: String,
    /// Point in time at which the token expires.
    pub expires_at: Instant,
}
216
217// ################################
218// RefreshingTokenSourceTaskHandle
219
220/// Handle to manage the background refresh task.
221/// When dropped, the task is aborted.
222struct RefreshingTokenSourceTaskHandle {
223 handle: JoinHandle<()>,
224}
225
226impl Drop for RefreshingTokenSourceTaskHandle {
227 fn drop(&mut self) {
228 self.handle.abort();
229 }
230}
231
232struct RefreshTokenSourceTask {
233 name: String,
234 result: Arc<RwLock<Option<Result<TokenWithExpiry, TokenSourceError>>>>,
235 refresh_notify: Arc<Notify>,
236 refresh_retry_delay: Duration,
237 refresh_threshold: Duration,
238 #[allow(clippy::type_complexity)]
239 token_refresher: Box<dyn TokenRefresher>,
240 min_token_lifetime: Duration,
241}
242
243impl RefreshTokenSourceTask {
244 fn run(self) -> RefreshingTokenSourceTaskHandle {
245 let handle = tokio::spawn(async move {
246 let mut fail_count = 0;
247 loop {
248 // Determine when to next refresh the token.
249 let token_expiry = match self.result.read().await.as_ref() {
250 Some(Ok(token)) => token.expires_at,
251 // No token yet, or last refreshes failed, try to get a new token immediately.
252 _ => Instant::now(),
253 };
254
255 let refresh_deadline = token_expiry
256 .checked_sub(self.refresh_threshold)
257 .unwrap_or_else(Instant::now);
258
259 tokio::time::sleep_until(refresh_deadline.into()).await;
260
261 // Attempt to refresh the token.
262 let new_token = self.token_refresher.refresh().await;
263
264 match new_token {
265 // Got a new token, store it and notify waiters
266 Ok(token) => {
267 // Validate that token has a decent expiry time
268 if token.expires_at <= Instant::now() + self.min_token_lifetime {
269 tracing::error!(
270 name=%self.name,
271 token_ttl = ?token.expires_at.saturating_duration_since(Instant::now()),
272 "Refreshed token is already expired or too close to expiry, ignoring",
273 );
274 // XXX(ake): Not sure if we should abort here instead?
275
276 // Wait before trying again to avoid busy looping
277 tokio::time::sleep(self.refresh_retry_delay).await;
278 continue;
279 }
280
281 fail_count = 0;
282 let token_ttl = format!(
283 "{}s",
284 token
285 .expires_at
286 .saturating_duration_since(Instant::now())
287 .as_secs()
288 );
289
290 tracing::info!(
291 name=%self.name,
292 token_ttl,
293 "Refreshed token",
294 );
295
296 // Store the new token and notify waiters
297 {
298 let mut write_guard = self.result.write().await;
299 *write_guard = Some(Ok(token));
300 self.refresh_notify.notify_waiters(); // Must be inside the write lock to avoid missed notifications
301 }
302 }
303 // Failed to refresh the token, log the error and retry after a delay
304 Err(e) => {
305 fail_count += 1;
306 let token_remaining_ttl = token_expiry
307 .saturating_duration_since(Instant::now())
308 .as_secs();
309
310 let token_remaining_ttl = format!("{token_remaining_ttl}s");
311 let next_try = format!("{}s", self.refresh_retry_delay.as_secs());
312
313 tracing::error!(
314 token_remaining_ttl,
315 next_try,
316 fail_count,
317 name= %self.name,
318 "Failed to refresh token: {e}",
319 );
320
321 // If the current token is still valid, keep it, otherwise store the error
322 // and notify waiters
323 if token_expiry <= Instant::now() + self.min_token_lifetime {
324 {
325 let mut write_guard = self.result.write().await;
326 *write_guard = Some(Err(e));
327 self.refresh_notify.notify_waiters(); // Must be inside the write lock to avoid missed notifications
328 }
329 continue;
330 }
331
332 tokio::time::sleep(self.refresh_retry_delay).await;
333 continue;
334 }
335 }
336 }
337 });
338
339 RefreshingTokenSourceTaskHandle { handle }
340 }
341}
342
343// ################################
344// TokenRefresher
345
346/// Anything which allows to refresh a token.
347///
348/// Default implementations are provided for async functions and closures
349#[async_trait]
350pub trait TokenRefresher: Send + Sync + 'static {
351 /// Refreshes the token and return the new token and its expiry time.
352 async fn refresh(&self) -> Result<TokenWithExpiry, TokenSourceError>;
353}
354
355/// Allow any async function or closure matching the signature to be used as a TokenRefresher.
356#[async_trait]
357impl<AsyncFn, FnFuture> TokenRefresher for AsyncFn
358where
359 AsyncFn: Fn() -> FnFuture + Send + Sync + 'static,
360 FnFuture: Future<Output = Result<TokenWithExpiry, TokenSourceError>> + Send,
361{
362 async fn refresh(&self) -> Result<TokenWithExpiry, TokenSourceError> {
363 (self)().await
364 }
365}