1use tokio::sync::Mutex;
2
3use crate::refresher::Refresher;
4use crate::{ServiceToken, Token};
5
6#[derive(Debug, thiserror::Error)]
11pub(crate) enum AutoRefreshError {
12 #[error("No token found")]
14 NotFound,
15 #[error("Token has expired")]
17 Expired,
18 #[error("Auth error: {0}")]
20 Auth(#[from] crate::AuthError),
21}
22
23impl From<AutoRefreshError> for crate::AuthError {
24 fn from(err: AutoRefreshError) -> Self {
25 match err {
26 AutoRefreshError::NotFound => crate::AuthError::NotAuthenticated,
27 AutoRefreshError::Expired => crate::AuthError::TokenExpired,
28 AutoRefreshError::Auth(e) => e,
29 }
30 }
31}
32
#[cfg_attr(doc, aquamarine::aquamarine)]
/// Caches a [`Token`] and refreshes it on demand through a [`Refresher`].
///
/// `get_token` (available when `R: Refresher`) serves the cached access token
/// while it is not expired, and asks the refresher for a new token once the
/// cached one reports `is_expired()`. Concurrent callers coordinate through
/// the async mutex around [`State`].
pub(crate) struct AutoRefresh<R> {
    /// Strategy used to obtain a credential (`try_credential`), exchange it
    /// for a new token (`refresh`), persist new tokens (`save`) and hand a
    /// credential back to the cached token after a failed refresh (`restore`).
    refresher: R,
    /// Cached token plus refresh bookkeeping. It is a `tokio` mutex because
    /// `get_token` may hold the guard across an `.await` on the refresher.
    state: Mutex<State>,
}
120
121struct State {
122 token: Option<Token>,
123 refresh_in_progress: bool,
124}
125
126impl<R> AutoRefresh<R> {
127 pub(crate) fn new(refresher: R) -> Self {
133 Self {
134 refresher,
135 state: Mutex::new(State {
136 token: None,
137 refresh_in_progress: false,
138 }),
139 }
140 }
141
142 pub(crate) fn with_token(refresher: R, token: Token) -> Self {
147 Self {
148 refresher,
149 state: Mutex::new(State {
150 token: Some(token),
151 refresh_in_progress: false,
152 }),
153 }
154 }
155}
156
157impl<R: Refresher> AutoRefresh<R> {
158 pub(crate) async fn get_token(&self) -> Result<ServiceToken, AutoRefreshError> {
160 let mut state = self.state.lock().await;
161
162 if state.token.is_none() {
164 let Some(credential) = self.refresher.try_credential(None) else {
165 return Err(AutoRefreshError::NotFound);
166 };
167 state.refresh_in_progress = true;
168 match self.refresher.refresh(&credential).await {
169 Ok(new_token) => {
170 self.refresher.save(&new_token);
171 let service_token = ServiceToken::new(new_token.access_token().clone());
172 state.token = Some(new_token);
173 state.refresh_in_progress = false;
174 return Ok(service_token);
175 }
176 Err(err) => {
177 state.refresh_in_progress = false;
178 return Err(AutoRefreshError::Auth(err));
179 }
180 }
181 }
182
183 let needs_refresh = state.token.as_ref().is_some_and(|t| t.is_expired());
184 if !needs_refresh {
185 let token = state.token.as_ref().ok_or(AutoRefreshError::NotFound)?;
187 return Ok(ServiceToken::new(token.access_token().clone()));
188 }
189
190 if state.refresh_in_progress {
192 let token = state.token.as_ref().ok_or(AutoRefreshError::NotFound)?;
193 if token.is_usable() {
194 return Ok(ServiceToken::new(token.access_token().clone()));
195 }
196 return Err(AutoRefreshError::Expired);
205 }
206
207 let credential = self.refresher.try_credential(state.token.as_mut());
209
210 let Some(credential) = credential else {
211 let token = state.token.as_ref().ok_or(AutoRefreshError::NotFound)?;
213 if token.is_usable() {
214 return Ok(ServiceToken::new(token.access_token().clone()));
215 }
216 return Err(AutoRefreshError::Expired);
217 };
218
219 state.refresh_in_progress = true;
220
221 let is_usable = state.token.as_ref().is_some_and(|t| t.is_usable());
223
224 if is_usable {
225 let current_service_token = ServiceToken::new(
228 state
229 .token
230 .as_ref()
231 .ok_or(AutoRefreshError::NotFound)?
232 .access_token()
233 .clone(),
234 );
235 drop(state);
236
237 match self.refresher.refresh(&credential).await {
238 Ok(new_token) => {
239 self.refresher.save(&new_token);
240 let mut state = self.state.lock().await;
241 state.token = Some(new_token);
242 state.refresh_in_progress = false;
243 }
244 Err(err) => {
245 tracing::warn!(%err, "token refresh failed (token still usable)");
246 let mut state = self.state.lock().await;
247 if let Some(token) = state.token.as_mut() {
248 self.refresher.restore(token, credential);
249 }
250 state.refresh_in_progress = false;
251 }
252 }
253
254 Ok(current_service_token)
255 } else {
256 match self.refresher.refresh(&credential).await {
258 Ok(new_token) => {
259 self.refresher.save(&new_token);
260 let service_token = ServiceToken::new(new_token.access_token().clone());
261 state.token = Some(new_token);
262 state.refresh_in_progress = false;
263 Ok(service_token)
264 }
265 Err(err) => {
266 tracing::warn!(%err, "token refresh failed");
267 if let Some(token) = state.token.as_mut() {
268 self.refresher.restore(token, credential);
269 }
270 state.refresh_in_progress = false;
271 Err(AutoRefreshError::Expired)
272 }
273 }
274 }
275 }
276}
277
#[cfg(test)]
mod tests {
    use super::*;
    use crate::oauth_refresher::OAuthRefresher;
    use crate::SecretToken;
    use mocktail::prelude::*;
    use stack_profile::ProfileStore;
    use std::sync::Arc;
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Builds a bearer token expiring `expires_in` seconds from now,
    /// optionally carrying a refresh token.
    fn make_token(access: &str, expires_in: u64, refresh: bool) -> Token {
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs();

        Token {
            access_token: SecretToken::new(access),
            token_type: "Bearer".to_string(),
            expires_at: now + expires_in,
            refresh_token: if refresh {
                Some(SecretToken::new("test-refresh-token"))
            } else {
                None
            },
            region: None,
            client_id: None,
            device_instance_id: None,
        }
    }

    /// JSON body of a successful `/oauth/token` response issuing `access`.
    fn refresh_response_json(access: &str) -> serde_json::Value {
        serde_json::json!({
            "access_token": access,
            "token_type": "Bearer",
            "expires_in": 3600,
            "refresh_token": "new-refresh-token"
        })
    }

    /// JSON body of an OAuth error response with the given `error` code.
    fn error_json(error: &str) -> serde_json::Value {
        serde_json::json!({
            "error": error,
            "error_description": format!("{error} occurred")
        })
    }

    /// Starts a mock HTTP server serving `mocks`.
    async fn start_server(mocks: MockSet) -> MockServer {
        let server = MockServer::new_http("auto-refresh-test").with_mocks(mocks);
        server.start().await.unwrap();
        server
    }

    /// Persists `token` to a profile store in `dir` and returns an
    /// `AutoRefresh` seeded with it that refreshes against `server`.
    fn auto_refresh_with_token(
        dir: &tempfile::TempDir,
        server: &MockServer,
        token: Token,
    ) -> AutoRefresh<OAuthRefresher> {
        let store = ProfileStore::new(dir.path());
        store.save_profile(&token).unwrap();
        let refresher = OAuthRefresher::new(
            Some(store),
            server.url(""),
            "cli",
            "ap-southeast-2.aws",
            None,
        );
        AutoRefresh::with_token(refresher, token)
    }

    /// A cached, non-expired token is returned as-is.
    #[tokio::test]
    async fn test_returns_cached_token() {
        let dir = tempfile::tempdir().unwrap();
        let server = start_server(MockSet::new()).await;
        let strategy =
            auto_refresh_with_token(&dir, &server, make_token("my-access-token", 3600, false));

        let token = strategy.get_token().await.unwrap();

        assert_eq!(token.as_str(), "my-access-token");
    }

    /// With no cached token and no stored profile, `NotFound` is reported.
    #[tokio::test]
    async fn test_returns_not_found_when_no_token_and_oauth() {
        let dir = tempfile::tempdir().unwrap();
        let server = start_server(MockSet::new()).await;
        // A path inside a fresh temp dir is guaranteed not to exist and is
        // portable, unlike a hard-coded `/tmp/...` path that may already hold
        // a profile (or not exist as a directory root at all on Windows).
        let store = ProfileStore::new(dir.path().join("nonexistent"));
        let refresher = OAuthRefresher::new(
            Some(store),
            server.url(""),
            "cli",
            "ap-southeast-2.aws",
            None,
        );
        let strategy = AutoRefresh::new(refresher);

        let err = strategy.get_token().await.unwrap_err();

        assert!(matches!(err, AutoRefreshError::NotFound));
    }

    /// The token is served from memory: deleting the on-disk profile does
    /// not affect subsequent calls.
    #[tokio::test]
    async fn test_caches_token_across_calls() {
        let dir = tempfile::tempdir().unwrap();
        let server = start_server(MockSet::new()).await;
        let strategy =
            auto_refresh_with_token(&dir, &server, make_token("my-access-token", 3600, false));

        let token1 = strategy.get_token().await.unwrap();
        assert_eq!(token1.as_str(), "my-access-token");

        std::fs::remove_file(dir.path().join("auth.json")).unwrap();

        let token2 = strategy.get_token().await.unwrap();
        assert_eq!(token2.as_str(), "my-access-token");
    }

    /// An expired token without a refresh token cannot be refreshed.
    #[tokio::test]
    async fn test_expired_token_without_refresh_token_returns_expired() {
        let dir = tempfile::tempdir().unwrap();
        let server = start_server(MockSet::new()).await;
        let strategy = auto_refresh_with_token(&dir, &server, make_token("old-token", 0, false));

        let err = strategy.get_token().await.unwrap_err();

        assert!(matches!(err, AutoRefreshError::Expired));
    }

    /// An expired token with a refresh token is exchanged for a new one.
    #[tokio::test]
    async fn test_refreshes_expiring_token() {
        let mut mocks = MockSet::new();
        mocks.mock(|when, then| {
            when.post().path("/oauth/token");
            then.json(refresh_response_json("refreshed-token"));
        });
        let server = start_server(mocks).await;
        let dir = tempfile::tempdir().unwrap();
        let strategy = auto_refresh_with_token(&dir, &server, make_token("old-token", 0, true));

        let token = strategy.get_token().await.unwrap();

        assert_eq!(token.as_str(), "refreshed-token");
    }

    /// A successful refresh is written back to the profile store.
    #[tokio::test]
    async fn test_refresh_persists_new_token_to_disk() {
        let mut mocks = MockSet::new();
        mocks.mock(|when, then| {
            when.post().path("/oauth/token");
            then.json(refresh_response_json("refreshed-token"));
        });
        let server = start_server(mocks).await;
        let dir = tempfile::tempdir().unwrap();
        let strategy = auto_refresh_with_token(&dir, &server, make_token("old-token", 0, true));

        let _ = strategy.get_token().await.unwrap();

        let store = ProfileStore::new(dir.path());
        let on_disk: Token = store.load_profile().unwrap();
        assert_eq!(on_disk.access_token().as_str(), "refreshed-token");
    }

    /// A failed refresh of an unusable token reports `Expired`.
    #[tokio::test]
    async fn test_refresh_failure_returns_expired_when_token_is_expired() {
        let mut mocks = MockSet::new();
        mocks.mock(|when, then| {
            when.post().path("/oauth/token");
            then.bad_request().json(error_json("invalid_grant"));
        });
        let server = start_server(mocks).await;
        let dir = tempfile::tempdir().unwrap();
        let strategy = auto_refresh_with_token(&dir, &server, make_token("old-token", 0, true));

        let err = strategy.get_token().await.unwrap_err();

        assert!(matches!(err, AutoRefreshError::Expired));
    }

    /// A fresh token never hits the token endpoint (which would error here).
    #[tokio::test]
    async fn test_does_not_refresh_fresh_token() {
        let mut mocks = MockSet::new();
        mocks.mock(|when, then| {
            when.post().path("/oauth/token");
            then.internal_server_error()
                .json(error_json("should_not_be_called"));
        });
        let server = start_server(mocks).await;
        let dir = tempfile::tempdir().unwrap();
        let strategy =
            auto_refresh_with_token(&dir, &server, make_token("fresh-token", 3600, true));

        let token = strategy.get_token().await.unwrap();

        assert_eq!(token.as_str(), "fresh-token");
    }

    /// After one successful refresh, a second call is served from the new
    /// token; the replacement mock fails if a second refresh is attempted.
    #[tokio::test]
    async fn test_refresh_token_is_taken_preventing_second_refresh() {
        let mut mocks = MockSet::new();
        mocks.mock(|when, then| {
            when.post().path("/oauth/token");
            then.json(refresh_response_json("refreshed-token"));
        });
        let server = start_server(mocks).await;
        let dir = tempfile::tempdir().unwrap();
        let strategy = auto_refresh_with_token(&dir, &server, make_token("old-token", 0, true));

        let token = strategy.get_token().await.unwrap();
        assert_eq!(token.as_str(), "refreshed-token");

        server.mocks().clear();
        server.mocks().mock(|when, then| {
            when.post().path("/oauth/token");
            then.bad_request().json(error_json("should_not_be_called"));
        });

        let token = strategy.get_token().await.unwrap();
        assert_eq!(token.as_str(), "refreshed-token");
    }

    /// A failed refresh puts the refresh token back so a later call can retry.
    #[tokio::test]
    async fn test_failed_refresh_restores_refresh_token_for_retry() {
        let mut mocks = MockSet::new();
        mocks.mock(|when, then| {
            when.post().path("/oauth/token");
            then.bad_request().json(error_json("invalid_grant"));
        });
        let server = start_server(mocks).await;
        let dir = tempfile::tempdir().unwrap();
        let strategy = auto_refresh_with_token(&dir, &server, make_token("old-token", 0, true));

        let err = strategy.get_token().await.unwrap_err();
        assert!(matches!(err, AutoRefreshError::Expired));

        let state = strategy.state.lock().await;
        assert!(state.token.is_some());
        assert!(state.token.as_ref().unwrap().refresh_token().is_some());
        drop(state);

        server.mocks().clear();
        server.mocks().mock(|when, then| {
            when.post().path("/oauth/token");
            then.json(refresh_response_json("refreshed-token"));
        });

        let token = strategy.get_token().await.unwrap();
        assert_eq!(token.as_str(), "refreshed-token");
    }

    /// An expiring-but-usable token is still served when its refresh fails,
    /// and its refresh token is restored.
    #[tokio::test]
    async fn test_access_token_remains_after_refresh_token_is_taken() {
        let mut mocks = MockSet::new();
        mocks.mock(|when, then| {
            when.post().path("/oauth/token");
            then.bad_request().json(error_json("server_error"));
        });
        let server = start_server(mocks).await;
        let dir = tempfile::tempdir().unwrap();
        let strategy = auto_refresh_with_token(&dir, &server, make_token("still-usable", 30, true));

        let token = strategy.get_token().await.unwrap();
        assert_eq!(token.as_str(), "still-usable");

        let state = strategy.state.lock().await;
        assert!(state.token.is_some());
        assert_eq!(
            state.token.as_ref().unwrap().access_token().as_str(),
            "still-usable"
        );
        assert!(
            state.token.as_ref().unwrap().refresh_token().is_some(),
            "refresh token should be restored after failed refresh"
        );
    }

    /// A failed background refresh of a usable token is retried on the next
    /// call, which then caches the refreshed token.
    #[tokio::test]
    async fn test_failed_refresh_of_usable_token_can_be_retried() {
        let mut mocks = MockSet::new();
        mocks.mock(|when, then| {
            when.post().path("/oauth/token");
            then.bad_request().json(error_json("server_error"));
        });
        let server = start_server(mocks).await;
        let dir = tempfile::tempdir().unwrap();
        let strategy = auto_refresh_with_token(&dir, &server, make_token("still-usable", 30, true));

        let token = strategy.get_token().await.unwrap();
        assert_eq!(token.as_str(), "still-usable");

        server.mocks().clear();
        server.mocks().mock(|when, then| {
            when.post().path("/oauth/token");
            then.json(refresh_response_json("refreshed-token"));
        });

        let token = strategy.get_token().await.unwrap();
        assert!(
            token.as_str() == "still-usable" || token.as_str() == "refreshed-token",
            "expected old or refreshed token, got: {}",
            token.as_str()
        );

        let state = strategy.state.lock().await;
        assert_eq!(
            state.token.as_ref().unwrap().access_token().as_str(),
            "refreshed-token"
        );
    }

    /// Once refreshed, repeated calls reuse the new token without refreshing.
    #[tokio::test]
    async fn test_multiple_sequential_calls_only_refresh_once() {
        let mut mocks = MockSet::new();
        mocks.mock(|when, then| {
            when.post().path("/oauth/token");
            then.json(refresh_response_json("refreshed-once"));
        });
        let server = start_server(mocks).await;
        let dir = tempfile::tempdir().unwrap();
        let strategy = auto_refresh_with_token(&dir, &server, make_token("old-token", 0, true));

        let token = strategy.get_token().await.unwrap();
        assert_eq!(token.as_str(), "refreshed-once");

        server.mocks().clear();
        server.mocks().mock(|when, then| {
            when.post().path("/oauth/token");
            then.json(refresh_response_json("refreshed-twice"));
        });

        for _ in 0..4 {
            let token = strategy.get_token().await.unwrap();
            assert_eq!(
                token.as_str(),
                "refreshed-once",
                "should return cached refreshed token, not trigger another refresh"
            );
        }
    }

    /// Two concurrent callers with an expiring-but-usable token both succeed.
    #[tokio::test]
    async fn test_concurrent_access_with_expiring_but_usable_token() {
        let mut mocks = MockSet::new();
        mocks.mock(|when, then| {
            when.post().path("/oauth/token");
            then.json(refresh_response_json("refreshed-token"));
        });
        let server = start_server(mocks).await;
        let dir = tempfile::tempdir().unwrap();
        let strategy = Arc::new(auto_refresh_with_token(
            &dir,
            &server,
            make_token("still-usable", 30, true),
        ));

        let s1 = Arc::clone(&strategy);
        let handle_a = tokio::spawn(async move { s1.get_token().await.unwrap() });

        let s2 = Arc::clone(&strategy);
        let handle_b = tokio::spawn(async move { s2.get_token().await.unwrap() });

        let (result_a, result_b) = tokio::join!(handle_a, handle_b);
        let token_a = result_a.unwrap();
        let token_b = result_b.unwrap();

        assert!(
            token_a.as_str() == "still-usable" || token_a.as_str() == "refreshed-token",
            "unexpected token_a: {}",
            token_a.as_str()
        );
        assert!(
            token_b.as_str() == "still-usable" || token_b.as_str() == "refreshed-token",
            "unexpected token_b: {}",
            token_b.as_str()
        );
    }

    /// Two concurrent callers with an expired token both receive the refreshed token.
    #[tokio::test]
    async fn test_concurrent_access_with_fully_expired_token() {
        let mut mocks = MockSet::new();
        mocks.mock(|when, then| {
            when.post().path("/oauth/token");
            then.json(refresh_response_json("refreshed-token"));
        });
        let server = start_server(mocks).await;
        let dir = tempfile::tempdir().unwrap();
        let strategy = Arc::new(auto_refresh_with_token(
            &dir,
            &server,
            make_token("expired-token", 0, true),
        ));

        let s1 = Arc::clone(&strategy);
        let handle_a = tokio::spawn(async move { s1.get_token().await.unwrap() });

        let s2 = Arc::clone(&strategy);
        let handle_b = tokio::spawn(async move { s2.get_token().await.unwrap() });

        let (result_a, result_b) = tokio::join!(handle_a, handle_b);
        let token_a = result_a.unwrap();
        let token_b = result_b.unwrap();

        assert_eq!(token_a.as_str(), "refreshed-token");
        assert_eq!(token_b.as_str(), "refreshed-token");
    }
}
716
#[cfg(test)]
mod stress_tests {
    use super::*;
    use crate::oauth_refresher::OAuthRefresher;
    use crate::SecretToken;
    use stack_profile::ProfileStore;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};

    const CONCURRENCY: usize = 50;

    /// Counts requests reaching the test token endpoint: the running total,
    /// how many are in flight right now, and the peak in-flight count.
    #[derive(Clone)]
    struct CountingState {
        total: Arc<AtomicUsize>,
        current: Arc<AtomicUsize>,
        peak: Arc<AtomicUsize>,
    }

    impl CountingState {
        fn new() -> Self {
            let zeroed = || Arc::new(AtomicUsize::new(0));
            Self {
                total: zeroed(),
                current: zeroed(),
                peak: zeroed(),
            }
        }

        /// Records the start of a request and updates the peak.
        fn enter(&self) {
            self.total.fetch_add(1, Ordering::SeqCst);
            let in_flight = self.current.fetch_add(1, Ordering::SeqCst) + 1;
            self.peak.fetch_max(in_flight, Ordering::SeqCst);
        }

        /// Records the end of a request.
        fn exit(&self) {
            self.current.fetch_sub(1, Ordering::SeqCst);
        }

        fn peak(&self) -> usize {
            self.peak.load(Ordering::SeqCst)
        }

        fn total(&self) -> usize {
            self.total.load(Ordering::SeqCst)
        }
    }

    /// Shared state of the test token endpoint: request counters plus the
    /// artificial latency applied to every request.
    #[derive(Clone)]
    struct DelayedRefreshState {
        counting: CountingState,
        delay: Duration,
    }

    /// Endpoint state with zeroed counters and the given per-request latency.
    fn delayed_state(delay: Duration) -> DelayedRefreshState {
        DelayedRefreshState {
            counting: CountingState::new(),
            delay,
        }
    }

    /// Marks one request as in flight for the configured delay.
    async fn simulate_latency(state: &DelayedRefreshState) {
        state.counting.enter();
        tokio::time::sleep(state.delay).await;
        state.counting.exit();
    }

    /// Token endpoint that answers with `refreshed-token` after the delay.
    async fn delayed_refresh_handler(
        axum::extract::State(state): axum::extract::State<DelayedRefreshState>,
    ) -> axum::Json<serde_json::Value> {
        simulate_latency(&state).await;
        axum::Json(serde_json::json!({
            "access_token": "refreshed-token",
            "token_type": "Bearer",
            "expires_in": 3600,
            "refresh_token": "new-refresh-token"
        }))
    }

    /// Token endpoint that answers with `400 invalid_grant` after the delay.
    async fn delayed_error_handler(
        axum::extract::State(state): axum::extract::State<DelayedRefreshState>,
    ) -> (axum::http::StatusCode, axum::Json<serde_json::Value>) {
        simulate_latency(&state).await;
        let body = serde_json::json!({
            "error": "invalid_grant",
            "error_description": "invalid_grant occurred"
        });
        (axum::http::StatusCode::BAD_REQUEST, axum::Json(body))
    }

    /// Serves `handler` at `/oauth/token` on an ephemeral local port and
    /// returns the base URL plus a handle onto the request counters.
    async fn start_axum_server<H, T>(
        handler: H,
        state: DelayedRefreshState,
    ) -> (url::Url, CountingState)
    where
        H: axum::handler::Handler<T, DelayedRefreshState> + Clone + Send + 'static,
        T: 'static,
    {
        let stats = state.counting.clone();
        let router = axum::Router::new()
            .route("/oauth/token", axum::routing::post(handler))
            .with_state(state);
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let address = listener.local_addr().unwrap();
        tokio::spawn(async move {
            axum::serve(listener, router).await.unwrap();
        });
        let base_url = url::Url::parse(&format!("http://{address}")).unwrap();
        (base_url, stats)
    }

    /// Builds a bearer token expiring `expires_in` seconds from now,
    /// optionally carrying a refresh token.
    fn make_token(access: &str, expires_in: u64, refresh: bool) -> Token {
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs();

        let refresh_token = refresh.then(|| SecretToken::new("test-refresh-token"));
        Token {
            access_token: SecretToken::new(access),
            token_type: "Bearer".to_string(),
            expires_at: now + expires_in,
            refresh_token,
            region: None,
            client_id: None,
            device_instance_id: None,
        }
    }

    /// Persists `token` to a profile store in `dir` and returns an
    /// `AutoRefresh` seeded with it that refreshes against `base_url`.
    fn auto_refresh_with_token(
        dir: &tempfile::TempDir,
        base_url: &url::Url,
        token: Token,
    ) -> AutoRefresh<OAuthRefresher> {
        let store = ProfileStore::new(dir.path());
        store.save_profile(&token).unwrap();
        let refresher = OAuthRefresher::new(
            Some(store),
            base_url.clone(),
            "cli",
            "ap-southeast-2.aws",
            None,
        );
        AutoRefresh::with_token(refresher, token)
    }

    /// Spawns `CONCURRENCY` tasks, each running `task` on a clone of
    /// `strategy`, and awaits them in spawn order (panicking if any panicked).
    async fn run_concurrently<T, Fut>(
        strategy: &Arc<AutoRefresh<OAuthRefresher>>,
        task: impl Fn(Arc<AutoRefresh<OAuthRefresher>>) -> Fut,
    ) -> Vec<T>
    where
        Fut: std::future::Future<Output = T> + Send + 'static,
        T: Send + 'static,
    {
        let handles: Vec<_> = (0..CONCURRENCY)
            .map(|_| tokio::spawn(task(Arc::clone(strategy))))
            .collect();
        let mut outputs = Vec::with_capacity(handles.len());
        for handle in handles {
            outputs.push(handle.await.unwrap());
        }
        outputs
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_concurrent_fresh_token_no_contention() {
        let (base_url, stats) =
            start_axum_server(delayed_refresh_handler, delayed_state(Duration::from_millis(500)))
                .await;
        let dir = tempfile::tempdir().unwrap();
        let strategy = Arc::new(auto_refresh_with_token(
            &dir,
            &base_url,
            make_token("fresh-token", 3600, true),
        ));

        let started = Instant::now();
        let tokens =
            run_concurrently(&strategy, |s| async move { s.get_token().await.unwrap() }).await;
        let elapsed = started.elapsed();

        for token in &tokens {
            assert_eq!(token.as_str(), "fresh-token");
        }

        assert!(
            elapsed < Duration::from_millis(200),
            "expected < 200ms for fresh tokens, got {:?}",
            elapsed
        );
        assert_eq!(stats.total(), 0, "no refresh requests should be made");
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_concurrent_expiring_token_non_blocking_reads() {
        let (base_url, stats) =
            start_axum_server(delayed_refresh_handler, delayed_state(Duration::from_millis(500)))
                .await;
        let dir = tempfile::tempdir().unwrap();
        let strategy = Arc::new(auto_refresh_with_token(
            &dir,
            &base_url,
            make_token("still-usable", 30, true),
        ));

        let started = Instant::now();
        let timed = run_concurrently(&strategy, |s| async move {
            let call_start = Instant::now();
            let token = s.get_token().await.unwrap();
            (token, call_start.elapsed())
        })
        .await;
        let elapsed = started.elapsed();

        for (token, _) in &timed {
            assert!(
                token.as_str() == "still-usable" || token.as_str() == "refreshed-token",
                "unexpected token: {}",
                token.as_str()
            );
        }

        let fast_callers = timed
            .iter()
            .filter(|(_, took)| *took < Duration::from_millis(100))
            .count();
        assert!(
            fast_callers >= CONCURRENCY - 1,
            "expected at least {} fast callers, got {} (total elapsed: {:?})",
            CONCURRENCY - 1,
            fast_callers,
            elapsed
        );

        assert_eq!(stats.peak(), 1, "peak concurrency to refresh endpoint");
        assert_eq!(stats.total(), 1, "total refresh requests");
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_concurrent_expired_token_blocks_until_refresh() {
        let refresh_delay = Duration::from_millis(200);
        let (base_url, stats) =
            start_axum_server(delayed_refresh_handler, delayed_state(refresh_delay)).await;
        let dir = tempfile::tempdir().unwrap();
        let strategy = Arc::new(auto_refresh_with_token(
            &dir,
            &base_url,
            make_token("expired-token", 0, true),
        ));

        let started = Instant::now();
        let tokens =
            run_concurrently(&strategy, |s| async move { s.get_token().await.unwrap() }).await;
        let elapsed = started.elapsed();

        for token in &tokens {
            assert_eq!(token.as_str(), "refreshed-token");
        }

        let budget = refresh_delay + Duration::from_millis(200);
        assert!(
            elapsed < budget,
            "expected < {:?} for blocked callers, got {:?}",
            budget,
            elapsed
        );

        assert_eq!(stats.peak(), 1, "peak concurrency to refresh endpoint");
        assert_eq!(stats.total(), 1, "total refresh requests");
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_concurrent_expired_token_refresh_failure_recovers() {
        let (base_url, stats) =
            start_axum_server(delayed_error_handler, delayed_state(Duration::from_millis(10)))
                .await;
        let dir = tempfile::tempdir().unwrap();
        let strategy = Arc::new(auto_refresh_with_token(
            &dir,
            &base_url,
            make_token("expired-token", 0, true),
        ));

        let outcomes = run_concurrently(&strategy, |s| async move { s.get_token().await }).await;

        for outcome in &outcomes {
            assert!(outcome.is_err(), "expected Expired error, got Ok");
            assert!(matches!(
                outcome.as_ref().unwrap_err(),
                AutoRefreshError::Expired
            ));
        }

        {
            let state = strategy.state.lock().await;
            assert!(
                state.token.as_ref().unwrap().refresh_token().is_some(),
                "refresh token should be restored after failed refresh"
            );
        }

        assert_eq!(stats.peak(), 1, "peak concurrency to refresh endpoint");
        assert!(
            stats.total() >= 1,
            "at least one refresh attempt should be made"
        );
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_concurrent_refresh_failure_then_retry() {
        // First wave: the endpoint always rejects the refresh.
        let (failing_url, _) =
            start_axum_server(delayed_error_handler, delayed_state(Duration::from_millis(50)))
                .await;
        let dir = tempfile::tempdir().unwrap();
        let failing = Arc::new(auto_refresh_with_token(
            &dir,
            &failing_url,
            make_token("expired-token", 0, true),
        ));

        let outcomes = run_concurrently(&failing, |s| async move { s.get_token().await }).await;

        for outcome in &outcomes {
            assert!(
                outcome.is_err(),
                "first wave: expected Expired, got Ok({})",
                outcome.as_ref().unwrap().as_str()
            );
        }

        // Second wave: a healthy endpoint, hit exactly once.
        let (healthy_url, healthy_stats) =
            start_axum_server(delayed_refresh_handler, delayed_state(Duration::from_millis(50)))
                .await;
        let healthy = Arc::new(auto_refresh_with_token(
            &dir,
            &healthy_url,
            make_token("expired-token", 0, true),
        ));

        let tokens =
            run_concurrently(&healthy, |s| async move { s.get_token().await.unwrap() }).await;

        for token in &tokens {
            assert_eq!(token.as_str(), "refreshed-token");
        }

        assert_eq!(
            healthy_stats.total(),
            1,
            "only one retry refresh should be made"
        );
    }
}