scion_sdk_reqwest_connect_rpc/token_source/
refresh.rs1use std::time::{Duration, Instant};
21
22use async_trait::async_trait;
23use tokio::{sync::watch, task::JoinHandle};
24
25use crate::token_source::{TokenSource, TokenSourceError};
26
/// Default delay before retrying after an unsuccessful refresh (5 s).
const DEFAULT_REFRESH_RETRY_DELAY: Duration = Duration::from_millis(5_000);
/// Default lead time before the current token's expiry at which a refresh starts (60 s).
const DEFAULT_REFRESH_THRESHOLD: Duration = Duration::from_millis(60_000);
/// Default for the builder's `refresh_timeout` setting (10 s).
const DEFAULT_REFRESH_TIMEOUT: Duration = Duration::from_millis(10_000);
/// Default for the builder's `min_token_lifetime` setting (10 s): a refreshed token
/// must remain valid for longer than this to be accepted.
const DEFAULT_MIN_TOKEN_LIFETIME: Duration = Duration::from_millis(10_000);
31
32pub struct RefreshTokenSourceBuilder<T: TokenRefresher> {
34 name: String,
35 token_refresher: T,
36 refresh_retry_delay: Duration,
37 refresh_threshold: Duration,
38 refresh_timeout: Duration,
39 min_token_lifetime: Duration,
40}
41impl<T: TokenRefresher> RefreshTokenSourceBuilder<T> {
42 pub fn new(name: String, token_refresher: T) -> Self {
48 Self {
49 name,
50 token_refresher,
51 refresh_retry_delay: DEFAULT_REFRESH_RETRY_DELAY,
52 refresh_threshold: DEFAULT_REFRESH_THRESHOLD,
53 refresh_timeout: DEFAULT_REFRESH_TIMEOUT,
54 min_token_lifetime: DEFAULT_MIN_TOKEN_LIFETIME,
55 }
56 }
57
58 pub fn min_token_lifetime(mut self, duration: Duration) -> Self {
60 self.min_token_lifetime = duration;
61 self
62 }
63
64 pub fn refresh_retry_delay(mut self, duration: Duration) -> Self {
66 self.refresh_retry_delay = duration;
67 self
68 }
69
70 pub fn refresh_threshold(mut self, duration: Duration) -> Self {
72 self.refresh_threshold = duration;
73 self
74 }
75
76 pub fn refresh_timeout(mut self, duration: Duration) -> Self {
78 self.refresh_timeout = duration;
79 self
80 }
81
82 pub fn build(self) -> RefreshTokenSource {
84 RefreshTokenSource::new(
85 self.name,
86 self.token_refresher,
87 self.refresh_retry_delay,
88 self.refresh_threshold,
89 self.refresh_timeout,
90 )
91 }
92}
93
/// A [`TokenSource`] whose token is kept up to date by a background Tokio task.
///
/// Subscribers obtain the current value through [`TokenSource::watch`]. The background
/// task is aborted when this value is dropped.
pub struct RefreshTokenSource {
    /// Receiving side of the channel the refresh task publishes into. It holds `None`
    /// until the task publishes its first token or error.
    watch_rx: watch::Receiver<Option<Result<String, TokenSourceError>>>,
    // Never read: the field exists only so that dropping the source drops the handle,
    // whose `Drop` impl aborts the refresh task.
    #[allow(unused)]
    task_handle: RefreshingTokenSourceTaskHandle,
}
106
107impl RefreshTokenSource {
108 pub fn builder<T: TokenRefresher>(
110 name: impl Into<String>,
111 token_refresher: T,
112 ) -> RefreshTokenSourceBuilder<T> {
113 RefreshTokenSourceBuilder::new(name.into(), token_refresher)
114 }
115
116 pub fn new(
126 name: String,
127 token_refresher: impl TokenRefresher,
128 refresh_retry_delay: Duration,
129 refresh_threshold: Duration,
130 min_token_lifetime: Duration,
131 ) -> Self {
132 let (watch_tx, watch_rx) = tokio::sync::watch::channel(None);
133 let inner = RefreshTokenSourceTask {
134 name,
135 watch_tx,
136 refresh_retry_delay,
137 refresh_threshold,
138 min_token_lifetime,
139 token_refresher: Box::new(token_refresher),
140 };
141
142 let task_handle = inner.run();
143
144 Self {
145 watch_rx,
146 task_handle,
147 }
148 }
149}
150
151#[async_trait]
152impl TokenSource for RefreshTokenSource {
153 fn watch(&self) -> watch::Receiver<Option<Result<String, TokenSourceError>>> {
154 self.watch_rx.clone()
155 }
156}
157
/// A token value together with the instant at which it stops being valid.
///
/// `Debug` is implemented by hand so the token value is redacted and cannot leak into
/// logs through `{:?}` formatting.
#[derive(Clone, PartialEq, Eq)]
pub struct TokenWithExpiry {
    /// The token value that is published to subscribers of the token source.
    pub token: String,
    /// Point in time after which the token is considered expired.
    pub expires_at: Instant,
}

impl std::fmt::Debug for TokenWithExpiry {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TokenWithExpiry")
            .field("token", &"<redacted>")
            .field("expires_at", &self.expires_at)
            .finish()
    }
}
166
167struct RefreshingTokenSourceTaskHandle {
173 handle: JoinHandle<()>,
174}
175
176impl Drop for RefreshingTokenSourceTaskHandle {
177 fn drop(&mut self) {
178 self.handle.abort();
179 }
180}
181
182struct RefreshTokenSourceTask {
183 name: String,
184 watch_tx: watch::Sender<Option<Result<String, TokenSourceError>>>,
185 refresh_retry_delay: Duration,
186 refresh_threshold: Duration,
187 #[allow(clippy::type_complexity)]
188 token_refresher: Box<dyn TokenRefresher>,
189 min_token_lifetime: Duration,
190}
191
192impl RefreshTokenSourceTask {
193 fn run(self) -> RefreshingTokenSourceTaskHandle {
194 let handle = tokio::spawn(async move {
195 let mut fail_count = 0;
196 let mut current_token: Option<TokenWithExpiry> = None;
197 loop {
198 let token_expiry = match current_token {
200 Some(ref token) => token.expires_at,
201 _ => Instant::now(),
203 };
204
205 let refresh_deadline = token_expiry
206 .checked_sub(self.refresh_threshold)
207 .unwrap_or_else(Instant::now);
208
209 tokio::time::sleep_until(refresh_deadline.into()).await;
210
211 let new_token = self.token_refresher.refresh().await;
213
214 match new_token {
215 Ok(token) => {
217 let token_ttl_secs = token
218 .expires_at
219 .saturating_duration_since(Instant::now())
220 .as_secs();
221
222 if token.expires_at <= Instant::now() + self.min_token_lifetime {
224 tracing::error!(
225 name = %self.name,
226 token_ttl_secs,
227 "Refreshed token is already expired or too close to expiry, ignoring"
228 );
229 tokio::time::sleep(self.refresh_retry_delay).await;
233 continue;
234 }
235
236 fail_count = 0;
237
238 tracing::info!(
239 name = %self.name,
240 token_ttl_secs,
241 "Refreshed token"
242 );
243
244 current_token = Some(token.clone());
245 self.watch_tx.send_replace(Some(Ok(token.token)));
246 }
247 Err(e) => {
249 fail_count += 1;
250
251 tracing::error!(
252 name = %self.name,
253 ttl_secs = token_expiry.saturating_duration_since(Instant::now()).as_secs(),
254 retry_secs = self.refresh_retry_delay.as_secs(),
255 fail_count,
256 error = %e,
257 "Failed to refresh token"
258 );
259
260 if token_expiry <= Instant::now() + self.min_token_lifetime {
263 current_token = None;
264 self.watch_tx.send_replace(Some(Err(e)));
265 continue;
266 }
267
268 tokio::time::sleep(self.refresh_retry_delay).await;
269 continue;
270 }
271 }
272 }
273 });
274
275 RefreshingTokenSourceTaskHandle { handle }
276 }
277}
278
/// Produces new tokens for a [`RefreshTokenSource`].
///
/// Any `Fn()` closure returning a `Send` future whose output is
/// `Result<TokenWithExpiry, TokenSourceError>` implements this trait automatically.
#[async_trait]
pub trait TokenRefresher: Send + Sync + 'static {
    /// Fetches a new token together with its expiry.
    ///
    /// # Errors
    ///
    /// Returns a [`TokenSourceError`] when no token could be obtained. The refresh task
    /// logs the error and retries. If the current token is expired or within
    /// `min_token_lifetime` of expiry, the task also publishes the error to subscribers.
    async fn refresh(&self) -> Result<TokenWithExpiry, TokenSourceError>;
}
290
291#[async_trait]
293impl<AsyncFn, FnFuture> TokenRefresher for AsyncFn
294where
295 AsyncFn: Fn() -> FnFuture + Send + Sync + 'static,
296 FnFuture: Future<Output = Result<TokenWithExpiry, TokenSourceError>> + Send,
297{
298 async fn refresh(&self) -> Result<TokenWithExpiry, TokenSourceError> {
299 (self)().await
300 }
301}