scion_sdk_reqwest_connect_rpc/token_source/refresh.rs
1// Copyright 2025 Anapaya Systems
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//! [`RefreshTokenSource`] automatically refreshes tokens before expiry using a configurable
15//! [`TokenRefresher`].
16//!
17//! Use the builder pattern to configure refresh intervals, timeouts, and minimum token lifetimes.
18//! See [`RefreshTokenSourceBuilder`] for details.
19
20use std::{
21 sync::Arc,
22 time::{Duration, Instant},
23};
24
25use async_trait::async_trait;
26use tokio::{
27 sync::{Notify, RwLock},
28 task::JoinHandle,
29 time::timeout,
30};
31
32use crate::token_source::{TokenSource, TokenSourceError};
33
/// Default pause before retrying after a refresh attempt failed.
const DEFAULT_REFRESH_RETRY_DELAY: Duration = Duration::from_secs(5);
/// Default lead time before a token's expiry at which a refresh is attempted.
const DEFAULT_REFRESH_THRESHOLD: Duration = Duration::from_secs(60);
/// Default upper bound on how long `get_token` waits for a pending refresh.
const DEFAULT_REFRESH_TIMEOUT: Duration = Duration::from_secs(10);
/// Default minimum remaining lifetime a token needs to be handed out by `get_token`.
const DEFAULT_MIN_TOKEN_LIFETIME: Duration = Duration::from_secs(10);
38
39/// Builder for a [RefreshTokenSource].
40pub struct RefreshTokenSourceBuilder<T: TokenRefresher> {
41 name: String,
42 token_refresher: T,
43 refresh_retry_delay: Duration,
44 refresh_threshold: Duration,
45 refresh_timeout: Duration,
46 min_token_lifetime: Duration,
47}
48impl<T: TokenRefresher> RefreshTokenSourceBuilder<T> {
49 /// Creates a new builder for a [RefreshTokenSource].
50 ///
51 /// # Arguments
52 /// * `name` - Name of the token source, used for logging.
53 /// * `token_refresher` - Ability to refresh the token.
54 pub fn new(name: String, token_refresher: T) -> Self {
55 Self {
56 name,
57 token_refresher,
58 refresh_retry_delay: DEFAULT_REFRESH_RETRY_DELAY,
59 refresh_threshold: DEFAULT_REFRESH_THRESHOLD,
60 refresh_timeout: DEFAULT_REFRESH_TIMEOUT,
61 min_token_lifetime: DEFAULT_MIN_TOKEN_LIFETIME,
62 }
63 }
64
65 /// Minimum lifetime a token must have to be considered valid when returned by `get_token`.
66 pub fn min_token_lifetime(mut self, duration: Duration) -> Self {
67 self.min_token_lifetime = duration;
68 self
69 }
70
71 /// The delay between retries if the refresh function fails.
72 pub fn refresh_retry_delay(mut self, duration: Duration) -> Self {
73 self.refresh_retry_delay = duration;
74 self
75 }
76
77 /// The duration before the token's expiry when a refresh should be attempted.
78 pub fn refresh_threshold(mut self, duration: Duration) -> Self {
79 self.refresh_threshold = duration;
80 self
81 }
82
83 /// The duration to wait for a refresh to complete when `get_token` is called before a timeout.
84 pub fn refresh_timeout(mut self, duration: Duration) -> Self {
85 self.refresh_timeout = duration;
86 self
87 }
88
89 /// Build the [RefreshTokenSource]
90 pub fn build(self) -> RefreshTokenSource {
91 RefreshTokenSource::new(
92 self.name,
93 self.token_refresher,
94 self.refresh_retry_delay,
95 self.refresh_threshold,
96 self.refresh_timeout,
97 self.min_token_lifetime,
98 )
99 }
100}
101
102// ################################
103// RefreshTokenSource
104
105/// A [TokenSource] automatically refreshing the token before it expires.
106pub struct RefreshTokenSource {
107 /// Shared state between the background refresh task and the token source.
108 result: Arc<RwLock<Option<Result<TokenWithExpiry, TokenSourceError>>>>,
109 /// Notifies waiters when a refresh has completed.
110 refresh_notify: Arc<Notify>,
111 /// The duration to wait for a refresh to complete when `get_token` is called before a timeout
112 /// error is returned.
113 refresh_timeout: Duration,
114 /// Minimum lifetime a token must have to be considered valid when returned by `get_token`.
115 min_token_lifetime: Duration,
116 // Handle to manage the background task, ensuring it is aborted when the `RefreshTokenSource`
117 // is dropped.
118 #[allow(unused)]
119 task_handle: RefreshingTokenSourceTaskHandle,
120}
121
122impl RefreshTokenSource {
123 /// Creates a builder for a `RefreshTokenSource`.
124 pub fn builder<T: TokenRefresher>(
125 name: impl Into<String>,
126 token_refresher: T,
127 ) -> RefreshTokenSourceBuilder<T> {
128 RefreshTokenSourceBuilder::new(name.into(), token_refresher)
129 }
130
131 /// Creates a new `RefreshTokenSource`.
132 ///
133 /// # Arguments
134 /// * `name` - Name of the token source, used for logging.
135 /// * `refresh_function` - Function to refresh the token.
136 /// * `refresh_retry_delay` - Delay between retries if the refresh function fails.
137 /// * `refresh_threshold` - Duration before the token's expiry when a refresh should be
138 /// attempted.
139 /// * `refresh_timeout` - Duration to wait for a refresh to complete when `get_token` is called.
140 pub fn new(
141 name: String,
142 token_refresher: impl TokenRefresher,
143 refresh_retry_delay: Duration,
144 refresh_threshold: Duration,
145 refresh_timeout: Duration,
146 min_token_lifetime: Duration,
147 ) -> Self {
148 let refresh_notify = Arc::new(Notify::new());
149 let result = Arc::new(RwLock::new(None));
150 let inner = RefreshTokenSourceTask {
151 name,
152 result: result.clone(),
153 refresh_notify: refresh_notify.clone(),
154 refresh_retry_delay,
155 refresh_threshold,
156 min_token_lifetime,
157 token_refresher: Box::new(token_refresher),
158 };
159
160 let task_handle = inner.run();
161
162 Self {
163 result: result.clone(),
164 refresh_notify,
165 refresh_timeout,
166 min_token_lifetime,
167 task_handle,
168 }
169 }
170}
171
172#[async_trait]
173impl TokenSource for RefreshTokenSource {
174 async fn get_token(&self) -> Result<String, TokenSourceError> {
175 loop {
176 let guard = self.result.read().await;
177
178 match guard.as_ref() {
179 // Return the token if it is still valid for at least `min_token_lifetime`
180 Some(Ok(token))
181 if token.expires_at > (Instant::now() + self.min_token_lifetime) =>
182 {
183 return Ok(token.token.clone());
184 }
185 // If we have an error, return it
186 Some(Err(e)) => {
187 // Stringify the error to avoid lifetime issues
188 return Err(e.to_string().into());
189 }
190 // If we have a expired token or no result wait for a refresh and try again
191 Some(Ok(_)) | None => {
192 let notify = self.refresh_notify.clone();
193 let notified = notify.notified();
194
195 // Must drop after getting a notified to avoid missing a notification
196 drop(guard);
197
198 timeout(self.refresh_timeout, notified)
199 .await
200 .map_err(|_| "timed out waiting for token refresh".to_string())?;
201
202 continue;
203 }
204 }
205 }
206 }
207}
208
/// A token together with the instant at which it expires.
#[derive(Clone)]
pub struct TokenWithExpiry {
    /// JWT string.
    pub token: String,
    /// Token expiry.
    pub expires_at: Instant,
}

/// Debug output redacts the token so credentials do not end up in logs; only its length and
/// the expiry are shown.
impl std::fmt::Debug for TokenWithExpiry {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TokenWithExpiry")
            .field(
                "token",
                &format_args!("<redacted, {} bytes>", self.token.len()),
            )
            .field("expires_at", &self.expires_at)
            .finish()
    }
}
216
217// ################################
218// RefreshingTokenSourceTaskHandle
219
220/// Handle to manage the background refresh task.
221/// When dropped, the task is aborted.
222struct RefreshingTokenSourceTaskHandle {
223 handle: JoinHandle<()>,
224}
225
226impl Drop for RefreshingTokenSourceTaskHandle {
227 fn drop(&mut self) {
228 self.handle.abort();
229 }
230}
231
232struct RefreshTokenSourceTask {
233 name: String,
234 result: Arc<RwLock<Option<Result<TokenWithExpiry, TokenSourceError>>>>,
235 refresh_notify: Arc<Notify>,
236 refresh_retry_delay: Duration,
237 refresh_threshold: Duration,
238 #[allow(clippy::type_complexity)]
239 token_refresher: Box<dyn TokenRefresher>,
240 min_token_lifetime: Duration,
241}
242
243impl RefreshTokenSourceTask {
244 fn run(self) -> RefreshingTokenSourceTaskHandle {
245 let handle = tokio::spawn(async move {
246 let mut fail_count = 0;
247 loop {
248 // Determine when to next refresh the token.
249 let token_expiry = match self.result.read().await.as_ref() {
250 Some(Ok(token)) => token.expires_at,
251 // No token yet, or last refreshes failed, try to get a new token immediately.
252 _ => Instant::now(),
253 };
254
255 let refresh_deadline = token_expiry
256 .checked_sub(self.refresh_threshold)
257 .unwrap_or_else(Instant::now);
258
259 tokio::time::sleep_until(refresh_deadline.into()).await;
260
261 // Attempt to refresh the token.
262 let new_token = self.token_refresher.refresh().await;
263
264 match new_token {
265 // Got a new token, store it and notify waiters
266 Ok(token) => {
267 let token_ttl_secs = token
268 .expires_at
269 .saturating_duration_since(Instant::now())
270 .as_secs();
271
272 // Validate that token has a decent expiry time
273 if token.expires_at <= Instant::now() + self.min_token_lifetime {
274 tracing::error!(
275 name = %self.name,
276 token_ttl_secs,
277 "Refreshed token is already expired or too close to expiry, ignoring"
278 );
279 // XXX(ake): Not sure if we should abort here instead?
280
281 // Wait before trying again to avoid busy looping
282 tokio::time::sleep(self.refresh_retry_delay).await;
283 continue;
284 }
285
286 fail_count = 0;
287
288 tracing::info!(
289 name = %self.name,
290 token_ttl_secs,
291 "Refreshed token"
292 );
293
294 // Store the new token and notify waiters
295 {
296 let mut write_guard = self.result.write().await;
297 *write_guard = Some(Ok(token));
298 self.refresh_notify.notify_waiters(); // Must be inside the write lock to avoid missed notifications
299 }
300 }
301 // Failed to refresh the token, log the error and retry after a delay
302 Err(e) => {
303 fail_count += 1;
304
305 tracing::error!(
306 name = %self.name,
307 ttl_secs = token_expiry.saturating_duration_since(Instant::now()).as_secs(),
308 retry_secs = self.refresh_retry_delay.as_secs(),
309 fail_count,
310 error = %e,
311 "Failed to refresh token"
312 );
313
314 // If the current token is still valid, keep it, otherwise store the error
315 // and notify waiters
316 if token_expiry <= Instant::now() + self.min_token_lifetime {
317 {
318 let mut write_guard = self.result.write().await;
319 *write_guard = Some(Err(e));
320 self.refresh_notify.notify_waiters(); // Must be inside the write lock to avoid missed notifications
321 }
322 continue;
323 }
324
325 tokio::time::sleep(self.refresh_retry_delay).await;
326 continue;
327 }
328 }
329 }
330 });
331
332 RefreshingTokenSourceTaskHandle { handle }
333 }
334}
335
336// ################################
337// TokenRefresher
338
339/// Anything which allows to refresh a token.
340///
341/// Default implementations are provided for async functions and closures
342#[async_trait]
343pub trait TokenRefresher: Send + Sync + 'static {
344 /// Refreshes the token and return the new token and its expiry time.
345 async fn refresh(&self) -> Result<TokenWithExpiry, TokenSourceError>;
346}
347
348/// Allow any async function or closure matching the signature to be used as a TokenRefresher.
349#[async_trait]
350impl<AsyncFn, FnFuture> TokenRefresher for AsyncFn
351where
352 AsyncFn: Fn() -> FnFuture + Send + Sync + 'static,
353 FnFuture: Future<Output = Result<TokenWithExpiry, TokenSourceError>> + Send,
354{
355 async fn refresh(&self) -> Result<TokenWithExpiry, TokenSourceError> {
356 (self)().await
357 }
358}