scion_sdk_reqwest_connect_rpc/token_source/
refresh.rs1use std::time::{Duration, Instant};
21
22use async_trait::async_trait;
23use tokio::{sync::watch, task::JoinHandle};
24
25use crate::token_source::{TokenSource, TokenSourceError};
26
/// Default pause between refresh retries.
const DEFAULT_REFRESH_RETRY_DELAY: Duration = Duration::from_millis(5_000);
/// Default lead time: a refresh starts this long before the current token expires.
const DEFAULT_REFRESH_THRESHOLD: Duration = Duration::from_millis(60_000);
/// Default value stored by `RefreshTokenSourceBuilder::refresh_timeout`.
const DEFAULT_REFRESH_TIMEOUT: Duration = Duration::from_millis(10_000);
/// Default minimum remaining lifetime a refreshed token needs in order to be accepted.
const DEFAULT_MIN_TOKEN_LIFETIME: Duration = Duration::from_millis(10_000);
31
32pub struct RefreshTokenSourceBuilder<T: TokenRefresher> {
34 name: String,
35 token_refresher: T,
36 refresh_retry_delay: Duration,
37 refresh_threshold: Duration,
38 refresh_timeout: Duration,
39 min_token_lifetime: Duration,
40 initial_token: Option<TokenWithExpiry>,
41}
42impl<T: TokenRefresher> RefreshTokenSourceBuilder<T> {
43 pub fn new(name: String, token_refresher: T) -> Self {
49 Self {
50 name,
51 token_refresher,
52 refresh_retry_delay: DEFAULT_REFRESH_RETRY_DELAY,
53 refresh_threshold: DEFAULT_REFRESH_THRESHOLD,
54 refresh_timeout: DEFAULT_REFRESH_TIMEOUT,
55 min_token_lifetime: DEFAULT_MIN_TOKEN_LIFETIME,
56 initial_token: None,
57 }
58 }
59
60 pub fn with_initial_token(mut self, token: TokenWithExpiry) -> Self {
66 self.initial_token = Some(token);
67 self
68 }
69
70 pub fn min_token_lifetime(mut self, duration: Duration) -> Self {
72 self.min_token_lifetime = duration;
73 self
74 }
75
76 pub fn refresh_retry_delay(mut self, duration: Duration) -> Self {
78 self.refresh_retry_delay = duration;
79 self
80 }
81
82 pub fn refresh_threshold(mut self, duration: Duration) -> Self {
84 self.refresh_threshold = duration;
85 self
86 }
87
88 pub fn refresh_timeout(mut self, duration: Duration) -> Self {
90 self.refresh_timeout = duration;
91 self
92 }
93
94 pub fn build(self) -> RefreshTokenSource {
96 RefreshTokenSource::new(
97 self.name,
98 self.token_refresher,
99 self.refresh_retry_delay,
100 self.refresh_threshold,
101 self.min_token_lifetime,
102 self.initial_token,
103 )
104 }
105}
106
107pub struct RefreshTokenSource {
112 watch_rx: watch::Receiver<Option<Result<String, TokenSourceError>>>,
114 #[allow(unused)]
117 task_handle: RefreshingTokenSourceTaskHandle,
118}
119
120impl RefreshTokenSource {
121 pub fn builder<T: TokenRefresher>(
123 name: impl Into<String>,
124 token_refresher: T,
125 ) -> RefreshTokenSourceBuilder<T> {
126 RefreshTokenSourceBuilder::new(name.into(), token_refresher)
127 }
128
129 pub fn new(
140 name: String,
141 token_refresher: impl TokenRefresher,
142 refresh_retry_delay: Duration,
143 refresh_threshold: Duration,
144 min_token_lifetime: Duration,
145 initial_token: Option<TokenWithExpiry>,
146 ) -> Self {
147 let (watch_tx, watch_rx) = tokio::sync::watch::channel(None);
148 let inner = RefreshTokenSourceTask {
149 name,
150 watch_tx,
151 refresh_retry_delay,
152 refresh_threshold,
153 min_token_lifetime,
154 token_refresher: Box::new(token_refresher),
155 initial_token,
156 };
157
158 let task_handle = inner.run();
159
160 Self {
161 watch_rx,
162 task_handle,
163 }
164 }
165}
166
167#[async_trait]
168impl TokenSource for RefreshTokenSource {
169 fn watch(&self) -> watch::Receiver<Option<Result<String, TokenSourceError>>> {
170 self.watch_rx.clone()
171 }
172}
173
/// A token value paired with the instant at which it stops being valid.
#[derive(Clone, Debug)]
pub struct TokenWithExpiry {
    /// The token string published to consumers.
    pub token: String,
    /// Monotonic instant after which the token should no longer be used.
    pub expires_at: Instant,
}
182
183struct RefreshingTokenSourceTaskHandle {
189 handle: JoinHandle<()>,
190}
191
192impl Drop for RefreshingTokenSourceTaskHandle {
193 fn drop(&mut self) {
194 self.handle.abort();
195 }
196}
197
198struct RefreshTokenSourceTask {
199 name: String,
200 watch_tx: watch::Sender<Option<Result<String, TokenSourceError>>>,
201 refresh_retry_delay: Duration,
202 refresh_threshold: Duration,
203 #[allow(clippy::type_complexity)]
204 token_refresher: Box<dyn TokenRefresher>,
205 min_token_lifetime: Duration,
206 initial_token: Option<TokenWithExpiry>,
207}
208
209impl RefreshTokenSourceTask {
210 fn run(self) -> RefreshingTokenSourceTaskHandle {
211 let handle = tokio::spawn(async move {
212 let mut fail_count = 0;
213 let mut current_token: Option<TokenWithExpiry> = if let Some(tok) = self.initial_token {
216 let token_ttl_secs = tok
217 .expires_at
218 .saturating_duration_since(Instant::now())
219 .as_secs();
220 tracing::debug!(
221 name = %self.name,
222 token_ttl_secs,
223 "Published initial token without calling refresh"
224 );
225 self.watch_tx.send_replace(Some(Ok(tok.token.clone())));
226 Some(tok)
227 } else {
228 None
229 };
230 loop {
231 let token_expiry = match current_token {
233 Some(ref token) => token.expires_at,
234 _ => Instant::now(),
236 };
237
238 let refresh_deadline = token_expiry
239 .checked_sub(self.refresh_threshold)
240 .unwrap_or_else(Instant::now);
241
242 tokio::time::sleep_until(refresh_deadline.into()).await;
243
244 let new_token = self.token_refresher.refresh().await;
246
247 match new_token {
248 Ok(token) => {
250 let token_ttl_secs = token
251 .expires_at
252 .saturating_duration_since(Instant::now())
253 .as_secs();
254
255 if token.expires_at <= Instant::now() + self.min_token_lifetime {
257 tracing::error!(
258 name = %self.name,
259 token_ttl_secs,
260 "Refreshed token is already expired or too close to expiry, ignoring"
261 );
262 tokio::time::sleep(self.refresh_retry_delay).await;
266 continue;
267 }
268
269 fail_count = 0;
270
271 tracing::info!(
272 name = %self.name,
273 token_ttl_secs,
274 "Refreshed token"
275 );
276
277 current_token = Some(token.clone());
278 self.watch_tx.send_replace(Some(Ok(token.token)));
279 }
280 Err(e) => {
282 fail_count += 1;
283
284 tracing::error!(
285 name = %self.name,
286 ttl_secs = token_expiry.saturating_duration_since(Instant::now()).as_secs(),
287 retry_secs = self.refresh_retry_delay.as_secs(),
288 fail_count,
289 error = %e,
290 "Failed to refresh token"
291 );
292
293 if token_expiry <= Instant::now() + self.min_token_lifetime {
296 current_token = None;
297 self.watch_tx.send_replace(Some(Err(e)));
298 continue;
299 }
300
301 tokio::time::sleep(self.refresh_retry_delay).await;
302 continue;
303 }
304 }
305 }
306 });
307
308 RefreshingTokenSourceTaskHandle { handle }
309 }
310}
311
312#[async_trait]
319pub trait TokenRefresher: Send + Sync + 'static {
320 async fn refresh(&self) -> Result<TokenWithExpiry, TokenSourceError>;
322}
323
324#[async_trait]
326impl<AsyncFn, FnFuture> TokenRefresher for AsyncFn
327where
328 AsyncFn: Fn() -> FnFuture + Send + Sync + 'static,
329 FnFuture: Future<Output = Result<TokenWithExpiry, TokenSourceError>> + Send,
330{
331 async fn refresh(&self) -> Result<TokenWithExpiry, TokenSourceError> {
332 (self)().await
333 }
334}
335
#[cfg(test)]
mod tests {
    use std::{
        sync::{
            Arc,
            atomic::{AtomicUsize, Ordering},
        },
        time::{Duration, Instant},
    };

    use tokio::sync::Notify;

    use super::*;

    /// A supplied initial token is visible to consumers straight away and the refresher is
    /// not called while that token is still fresh.
    #[tokio::test]
    async fn initial_token_is_published_without_calling_refresh() {
        let refresh_count = Arc::new(AtomicUsize::new(0));
        let refresh_count_clone = Arc::clone(&refresh_count);

        let source = RefreshTokenSource::builder("test", move || {
            refresh_count_clone.fetch_add(1, Ordering::SeqCst);
            let token = TokenWithExpiry {
                token: "refreshed-token".to_string(),
                expires_at: Instant::now() + Duration::from_secs(3600),
            };
            async move { Ok::<_, TokenSourceError>(token) }
        })
        .with_initial_token(TokenWithExpiry {
            token: "initial-token".to_string(),
            expires_at: Instant::now() + Duration::from_secs(3600),
        })
        .refresh_threshold(Duration::from_secs(60))
        .build();

        // Let the spawned refresh task run until it parks on its first sleep.
        tokio::task::yield_now().await;

        let mut rx = source.watch();
        let borrow = rx.borrow_and_update();
        match borrow.as_ref() {
            Some(Ok(token)) => assert_eq!(token, "initial-token"),
            other => panic!("expected initial token, got {other:?}"),
        }
        drop(borrow);

        assert_eq!(
            refresh_count.load(Ordering::SeqCst),
            0,
            "refresh() should not be called when an initial token is provided"
        );
    }

    /// Once the initial token reaches its refresh deadline, the refresher is called and its
    /// token replaces the initial one.
    #[tokio::test]
    async fn initial_token_expiry_triggers_refresh() {
        let notify = Arc::new(Notify::new());
        let notify_clone = Arc::clone(&notify);

        let source = RefreshTokenSource::builder("test", move || {
            notify_clone.notify_one();
            let token = TokenWithExpiry {
                token: "refreshed-token".to_string(),
                expires_at: Instant::now() + Duration::from_secs(3600),
            };
            async move { Ok::<_, TokenSourceError>(token) }
        })
        .with_initial_token(TokenWithExpiry {
            token: "initial-token".to_string(),
            expires_at: Instant::now() + Duration::from_millis(10),
        })
        .refresh_threshold(Duration::ZERO)
        .build();

        tokio::task::yield_now().await;

        tokio::time::timeout(Duration::from_millis(500), notify.notified())
            .await
            .expect("refresh() should be called after the initial token expires");
        let token = tokio::time::timeout(Duration::from_millis(500), source.get_token())
            .await
            .expect("get_token() should not timeout")
            .expect("get_token() should succeed");
        assert_eq!(token, "refreshed-token");
    }
}