1use std::sync::atomic::{AtomicBool, Ordering};
2
3use tokio::sync::{Mutex, MutexGuard, Notify};
4
5use crate::refresher::Refresher;
6use crate::{ServiceToken, Token};
7
8#[derive(Debug, thiserror::Error)]
13pub(crate) enum AutoRefreshError {
14 #[error("No token found")]
16 NotFound,
17 #[error("Token has expired")]
19 Expired,
20 #[error("Auth error: {0}")]
22 Auth(#[from] crate::AuthError),
23}
24
25impl From<AutoRefreshError> for crate::AuthError {
26 fn from(err: AutoRefreshError) -> Self {
27 match err {
28 AutoRefreshError::NotFound => crate::AuthError::NotAuthenticated,
29 AutoRefreshError::Expired => crate::AuthError::TokenExpired,
30 AutoRefreshError::Auth(e) => e,
31 }
32 }
33}
34
/// A cached token paired with a refresher that renews it on demand.
///
/// `R` supplies the credential handling and the actual refresh call; the
/// methods that need it are bounded on `Refresher` in the `impl` blocks.
pub(crate) struct AutoRefresh<R> {
    /// Produces credentials, performs refreshes, and persists/restores tokens.
    refresher: R,
    /// The cached token, guarded by an async mutex.
    state: Mutex<State>,
    /// Set to `true` while a refresh request is running.
    refresh_in_progress: AtomicBool,
    /// Wakes callers that are waiting for a running refresh to finish.
    refresh_notify: Notify,
}
51
/// Mutable state protected by the mutex in `AutoRefresh`.
struct State {
    /// The currently cached token, or `None` when nothing has been loaded yet.
    token: Option<Token>,
}
55
56struct CancelGuard<'a> {
62 in_progress: &'a AtomicBool,
63 notify: &'a Notify,
64 defused: bool,
65}
66
67impl Drop for CancelGuard<'_> {
68 fn drop(&mut self) {
69 if !self.defused {
70 self.in_progress.store(false, Ordering::Release);
71 self.notify.notify_waiters();
72 }
73 }
74}
75
76impl CancelGuard<'_> {
77 fn defuse(&mut self) {
78 self.defused = true;
79 }
80}
81
82impl State {
83 fn service_token(&self) -> Result<ServiceToken, AutoRefreshError> {
84 let token = self.token.as_ref().ok_or(AutoRefreshError::NotFound)?;
85 Ok(ServiceToken::new(token.access_token().clone()))
86 }
87
88 fn require_usable_token(&self) -> Result<ServiceToken, AutoRefreshError> {
89 let token = self.token.as_ref().ok_or(AutoRefreshError::NotFound)?;
90 if token.is_usable() {
91 Ok(ServiceToken::new(token.access_token().clone()))
92 } else {
93 Err(AutoRefreshError::Expired)
94 }
95 }
96}
97
98impl<R> AutoRefresh<R> {
99 pub(crate) fn new(refresher: R) -> Self {
105 Self {
106 refresher,
107 state: Mutex::new(State { token: None }),
108 refresh_in_progress: AtomicBool::new(false),
109 refresh_notify: Notify::new(),
110 }
111 }
112
113 pub(crate) fn with_token(refresher: R, token: Token) -> Self {
118 Self {
119 refresher,
120 state: Mutex::new(State { token: Some(token) }),
121 refresh_in_progress: AtomicBool::new(false),
122 refresh_notify: Notify::new(),
123 }
124 }
125}
126
127impl<R: Refresher> AutoRefresh<R> {
128 pub(crate) async fn get_token(&self) -> Result<ServiceToken, AutoRefreshError> {
130 let mut state = self.state.lock().await;
131
132 if state.token.is_none() {
133 return self.initial_auth(&mut state).await;
134 }
135
136 if !state.token.as_ref().is_some_and(|t| t.is_expired()) {
137 return state.service_token();
138 }
139
140 if self.refresh_in_progress.load(Ordering::Acquire) {
141 return self.wait_for_in_flight_refresh(state).await;
142 }
143
144 let Some(credential) = self.refresher.try_credential(state.token.as_mut()) else {
145 return state.require_usable_token();
146 };
147
148 self.refresh_in_progress.store(true, Ordering::Release);
149
150 if state.token.as_ref().is_some_and(|t| t.is_usable()) {
151 self.refresh_non_blocking(state, credential).await
152 } else {
153 self.refresh_blocking(&mut state, credential).await
154 }
155 }
156
157 async fn initial_auth(&self, state: &mut State) -> Result<ServiceToken, AutoRefreshError> {
161 let Some(credential) = self.refresher.try_credential(None) else {
162 return Err(AutoRefreshError::NotFound);
163 };
164 self.refresh_in_progress.store(true, Ordering::Release);
165 let mut guard = CancelGuard {
166 in_progress: &self.refresh_in_progress,
167 notify: &self.refresh_notify,
168 defused: false,
169 };
170 match self.refresher.refresh(&credential).await {
171 Ok(new_token) => {
172 guard.defuse();
173 self.refresher.save(&new_token);
174 let service_token = ServiceToken::new(new_token.access_token().clone());
175 state.token = Some(new_token);
176 self.refresh_in_progress.store(false, Ordering::Release);
177 Ok(service_token)
178 }
179 Err(err) => {
180 guard.defuse();
181 self.refresh_in_progress.store(false, Ordering::Release);
182 Err(AutoRefreshError::Auth(err))
183 }
184 }
185 }
186
187 async fn wait_for_in_flight_refresh(
193 &self,
194 state: MutexGuard<'_, State>,
195 ) -> Result<ServiceToken, AutoRefreshError> {
196 if let Ok(token) = state.service_token() {
197 if state.token.as_ref().is_some_and(|t| t.is_usable()) {
198 return Ok(token);
199 }
200 }
201 let notified = self.refresh_notify.notified();
204 drop(state);
205 notified.await;
206 let state = self.state.lock().await;
208 state.require_usable_token()
209 }
210
211 async fn refresh_non_blocking(
221 &self,
222 state: MutexGuard<'_, State>,
223 credential: R::Credential,
224 ) -> Result<ServiceToken, AutoRefreshError> {
225 let current_service_token = state.service_token()?;
226 drop(state);
227
228 let mut guard = CancelGuard {
229 in_progress: &self.refresh_in_progress,
230 notify: &self.refresh_notify,
231 defused: false,
232 };
233
234 match self.refresher.refresh(&credential).await {
235 Ok(new_token) => {
236 guard.defuse();
237 self.refresher.save(&new_token);
238 let mut state = self.state.lock().await;
239 state.token = Some(new_token);
240 self.refresh_in_progress.store(false, Ordering::Release);
241 }
242 Err(err) => {
243 guard.defuse();
244 tracing::warn!(%err, "token refresh failed (token still usable)");
245 let mut state = self.state.lock().await;
246 if let Some(token) = state.token.as_mut() {
247 self.refresher.restore(token, credential);
248 }
249 self.refresh_in_progress.store(false, Ordering::Release);
250 }
251 }
252
253 self.refresh_notify.notify_waiters();
254 Ok(current_service_token)
255 }
256
257 async fn refresh_blocking(
266 &self,
267 state: &mut State,
268 credential: R::Credential,
269 ) -> Result<ServiceToken, AutoRefreshError> {
270 let mut guard = CancelGuard {
271 in_progress: &self.refresh_in_progress,
272 notify: &self.refresh_notify,
273 defused: false,
274 };
275 match self.refresher.refresh(&credential).await {
276 Ok(new_token) => {
277 guard.defuse();
278 self.refresher.save(&new_token);
279 let service_token = ServiceToken::new(new_token.access_token().clone());
280 state.token = Some(new_token);
281 self.refresh_in_progress.store(false, Ordering::Release);
282 Ok(service_token)
283 }
284 Err(err) => {
285 guard.defuse();
286 tracing::warn!(%err, "token refresh failed");
287 if let Some(token) = state.token.as_mut() {
288 self.refresher.restore(token, credential);
289 }
290 self.refresh_in_progress.store(false, Ordering::Release);
291 Err(AutoRefreshError::Expired)
292 }
293 }
294 }
295}
296
297#[cfg(test)]
298#[allow(clippy::unwrap_used)]
299mod tests {
300 use super::*;
301 use crate::oauth_refresher::OAuthRefresher;
302 use crate::SecretToken;
303 use mocktail::prelude::*;
304 use stack_profile::ProfileStore;
305 use std::sync::Arc;
306 use std::time::{SystemTime, UNIX_EPOCH};
307
308 fn make_token(access: &str, expires_in: u64, refresh: bool) -> Token {
309 let now = SystemTime::now()
310 .duration_since(UNIX_EPOCH)
311 .unwrap()
312 .as_secs();
313
314 Token {
315 access_token: SecretToken::new(access),
316 token_type: "Bearer".to_string(),
317 expires_at: now + expires_in,
318 refresh_token: if refresh {
319 Some(SecretToken::new("test-refresh-token"))
320 } else {
321 None
322 },
323 region: None,
324 client_id: None,
325 device_instance_id: None,
326 }
327 }
328
329 fn refresh_response_json(access: &str) -> serde_json::Value {
330 serde_json::json!({
331 "access_token": access,
332 "token_type": "Bearer",
333 "expires_in": 3600,
334 "refresh_token": "new-refresh-token"
335 })
336 }
337
338 fn error_json(error: &str) -> serde_json::Value {
339 serde_json::json!({
340 "error": error,
341 "error_description": format!("{error} occurred")
342 })
343 }
344
    /// Creates an HTTP mock server named `auto-refresh-test` with the given
    /// mocks, starts it, and returns it. Panics if the server fails to start.
    async fn start_server(mocks: MockSet) -> MockServer {
        let server = MockServer::new_http("auto-refresh-test").with_mocks(mocks);
        server.start().await.unwrap();
        server
    }
350
351 fn auto_refresh_with_token(
352 dir: &tempfile::TempDir,
353 server: &MockServer,
354 token: Token,
355 ) -> AutoRefresh<OAuthRefresher> {
356 let store = ProfileStore::new(dir.path());
357 store.init_workspace("ZVATKW3VHMFG27DY").unwrap();
358 let ws_store = store.current_workspace_store().unwrap();
359 ws_store.save_profile(&token).unwrap();
360 let refresher = OAuthRefresher::new(
361 Some(ws_store),
362 server.url(""),
363 "cli",
364 "ap-southeast-2.aws",
365 None,
366 );
367 AutoRefresh::with_token(refresher, token)
368 }
369
370 mod given_no_cached_token {
371 use super::*;
372
373 #[tokio::test]
374 async fn returns_not_found_for_oauth() {
375 let server = start_server(MockSet::new()).await;
376 let store = ProfileStore::new("/tmp/nonexistent");
377 let refresher = OAuthRefresher::new(
378 Some(store),
379 server.url(""),
380 "cli",
381 "ap-southeast-2.aws",
382 None,
383 );
384 let strategy = AutoRefresh::new(refresher);
385
386 let err = strategy.get_token().await.unwrap_err();
387
388 assert!(
389 matches!(err, AutoRefreshError::NotFound),
390 "expected NotFound, got: {err:?}"
391 );
392 }
393 }
394
395 mod given_fresh_token {
396 use super::*;
397
398 #[tokio::test]
399 async fn returns_cached_token() {
400 let dir = tempfile::tempdir().unwrap();
401 let server = start_server(MockSet::new()).await;
402 let strategy =
403 auto_refresh_with_token(&dir, &server, make_token("my-access-token", 3600, false));
404
405 let token = strategy.get_token().await.unwrap();
406
407 assert_eq!(
408 token.as_str(),
409 "my-access-token",
410 "should return the cached access token"
411 );
412 }
413
414 #[tokio::test]
415 async fn caches_across_calls() {
416 let dir = tempfile::tempdir().unwrap();
417 let server = start_server(MockSet::new()).await;
418 let strategy =
419 auto_refresh_with_token(&dir, &server, make_token("my-access-token", 3600, false));
420
421 let token1 = strategy.get_token().await.unwrap();
422 assert_eq!(
423 token1.as_str(),
424 "my-access-token",
425 "first call should return the cached token"
426 );
427
428 std::fs::remove_file(
430 dir.path()
431 .join("workspaces")
432 .join("ZVATKW3VHMFG27DY")
433 .join("auth.json"),
434 )
435 .unwrap();
436
437 let token2 = strategy.get_token().await.unwrap();
438 assert_eq!(
439 token2.as_str(),
440 "my-access-token",
441 "second call should return the cached token even after file deletion"
442 );
443 }
444
445 #[tokio::test]
446 async fn does_not_trigger_refresh() {
447 let mut mocks = MockSet::new();
449 mocks.mock(|when, then| {
450 when.post().path("/oauth/token");
451 then.internal_server_error()
452 .json(error_json("should_not_be_called"));
453 });
454 let server = start_server(mocks).await;
455 let dir = tempfile::tempdir().unwrap();
456 let strategy =
457 auto_refresh_with_token(&dir, &server, make_token("fresh-token", 3600, true));
458
459 let token = strategy.get_token().await.unwrap();
460
461 assert_eq!(
462 token.as_str(),
463 "fresh-token",
464 "should return fresh token without triggering refresh"
465 );
466 }
467 }
468
469 mod given_fully_expired_token {
470 use super::*;
471
472 mod without_refresh_token {
473 use super::*;
474
475 #[tokio::test]
476 async fn returns_expired() {
477 let dir = tempfile::tempdir().unwrap();
478 let server = start_server(MockSet::new()).await;
479 let strategy =
480 auto_refresh_with_token(&dir, &server, make_token("old-token", 0, false));
481
482 let err = strategy.get_token().await.unwrap_err();
483
484 assert!(
485 matches!(err, AutoRefreshError::Expired),
486 "expected Expired, got: {err:?}"
487 );
488 }
489 }
490
491 mod with_refresh_token {
492 use super::*;
493
494 #[tokio::test]
495 async fn refreshes_and_returns_new_token() {
496 let mut mocks = MockSet::new();
497 mocks.mock(|when, then| {
498 when.post().path("/oauth/token");
499 then.json(refresh_response_json("refreshed-token"));
500 });
501 let server = start_server(mocks).await;
502 let dir = tempfile::tempdir().unwrap();
503 let strategy =
504 auto_refresh_with_token(&dir, &server, make_token("old-token", 0, true));
505
506 let token = strategy.get_token().await.unwrap();
507
508 assert_eq!(
509 token.as_str(),
510 "refreshed-token",
511 "should return the refreshed token"
512 );
513 }
514
515 #[tokio::test]
516 async fn persists_refreshed_token_to_disk() {
517 let mut mocks = MockSet::new();
518 mocks.mock(|when, then| {
519 when.post().path("/oauth/token");
520 then.json(refresh_response_json("refreshed-token"));
521 });
522 let server = start_server(mocks).await;
523 let dir = tempfile::tempdir().unwrap();
524 let strategy =
525 auto_refresh_with_token(&dir, &server, make_token("old-token", 0, true));
526
527 let _ = strategy.get_token().await.unwrap();
528
529 let store = ProfileStore::new(dir.path());
531 let ws_store = store.current_workspace_store().unwrap();
532 let on_disk: Token = ws_store.load_profile().unwrap();
533 assert_eq!(
534 on_disk.access_token().as_str(),
535 "refreshed-token",
536 "refreshed token should be persisted to disk"
537 );
538 }
539
540 #[tokio::test]
541 async fn returns_expired_on_refresh_failure() {
542 let mut mocks = MockSet::new();
543 mocks.mock(|when, then| {
544 when.post().path("/oauth/token");
545 then.bad_request().json(error_json("invalid_grant"));
546 });
547 let server = start_server(mocks).await;
548 let dir = tempfile::tempdir().unwrap();
549 let strategy =
550 auto_refresh_with_token(&dir, &server, make_token("old-token", 0, true));
551
552 let err = strategy.get_token().await.unwrap_err();
553
554 assert!(
555 matches!(err, AutoRefreshError::Expired),
556 "expected Expired after failed refresh, got: {err:?}"
557 );
558 }
559
560 #[tokio::test]
561 async fn restores_refresh_token_after_failure() {
562 let mut mocks = MockSet::new();
563 mocks.mock(|when, then| {
564 when.post().path("/oauth/token");
565 then.bad_request().json(error_json("invalid_grant"));
566 });
567 let server = start_server(mocks).await;
568 let dir = tempfile::tempdir().unwrap();
569 let strategy =
570 auto_refresh_with_token(&dir, &server, make_token("old-token", 0, true));
571
572 let err = strategy.get_token().await.unwrap_err();
574 assert!(
575 matches!(err, AutoRefreshError::Expired),
576 "expected Expired on first attempt, got: {err:?}"
577 );
578
579 let state = strategy.state.lock().await;
581 assert!(
582 state.token.is_some(),
583 "token should still be cached after failed refresh"
584 );
585 assert!(
586 state.token.as_ref().unwrap().refresh_token().is_some(),
587 "refresh token should be restored for retry"
588 );
589 drop(state);
590
591 server.mocks().clear();
593 server.mocks().mock(|when, then| {
594 when.post().path("/oauth/token");
595 then.json(refresh_response_json("refreshed-token"));
596 });
597
598 let token = strategy.get_token().await.unwrap();
600 assert_eq!(
601 token.as_str(),
602 "refreshed-token",
603 "retry should succeed with restored refresh token"
604 );
605 }
606
607 #[tokio::test]
608 async fn sequential_calls_only_refresh_once() {
609 let mut mocks = MockSet::new();
610 mocks.mock(|when, then| {
611 when.post().path("/oauth/token");
612 then.json(refresh_response_json("refreshed-once"));
613 });
614 let server = start_server(mocks).await;
615 let dir = tempfile::tempdir().unwrap();
616 let strategy =
617 auto_refresh_with_token(&dir, &server, make_token("old-token", 0, true));
618
619 let token = strategy.get_token().await.unwrap();
621 assert_eq!(
622 token.as_str(),
623 "refreshed-once",
624 "first call should trigger refresh"
625 );
626
627 server.mocks().clear();
629 server.mocks().mock(|when, then| {
630 when.post().path("/oauth/token");
631 then.json(refresh_response_json("refreshed-twice"));
632 });
633
634 for _ in 0..4 {
636 let token = strategy.get_token().await.unwrap();
637 assert_eq!(
638 token.as_str(),
639 "refreshed-once",
640 "should return cached refreshed token, not trigger another refresh"
641 );
642 }
643 }
644
645 #[tokio::test]
646 async fn prevents_second_refresh_after_success() {
647 let mut mocks = MockSet::new();
648 mocks.mock(|when, then| {
649 when.post().path("/oauth/token");
650 then.json(refresh_response_json("refreshed-token"));
651 });
652 let server = start_server(mocks).await;
653 let dir = tempfile::tempdir().unwrap();
654 let strategy =
655 auto_refresh_with_token(&dir, &server, make_token("old-token", 0, true));
656
657 let token = strategy.get_token().await.unwrap();
659 assert_eq!(
660 token.as_str(),
661 "refreshed-token",
662 "first call should refresh the token"
663 );
664
665 server.mocks().clear();
667 server.mocks().mock(|when, then| {
668 when.post().path("/oauth/token");
669 then.bad_request().json(error_json("should_not_be_called"));
670 });
671
672 let token = strategy.get_token().await.unwrap();
675 assert_eq!(
676 token.as_str(),
677 "refreshed-token",
678 "second call should return cached refreshed token"
679 );
680 }
681 }
682 }
683
684 mod given_expiring_but_usable_token {
685 use super::*;
686
687 mod when_refresh_fails {
688 use super::*;
689
690 #[tokio::test]
691 async fn returns_current_token() {
692 let mut mocks = MockSet::new();
693 mocks.mock(|when, then| {
694 when.post().path("/oauth/token");
695 then.bad_request().json(error_json("server_error"));
696 });
697 let server = start_server(mocks).await;
698 let dir = tempfile::tempdir().unwrap();
699 let strategy =
702 auto_refresh_with_token(&dir, &server, make_token("still-usable", 30, true));
703
704 let token = strategy.get_token().await.unwrap();
707 assert_eq!(
708 token.as_str(),
709 "still-usable",
710 "should return still-usable token despite failed refresh"
711 );
712
713 let state = strategy.state.lock().await;
715 assert!(state.token.is_some(), "token should still be cached");
716 assert_eq!(
717 state.token.as_ref().unwrap().access_token().as_str(),
718 "still-usable",
719 "access token should be unchanged after failed refresh"
720 );
721 assert!(
722 state.token.as_ref().unwrap().refresh_token().is_some(),
723 "refresh token should be restored after failed refresh"
724 );
725 }
726
727 #[tokio::test]
728 async fn restores_refresh_token_for_retry() {
729 let mut mocks = MockSet::new();
730 mocks.mock(|when, then| {
731 when.post().path("/oauth/token");
732 then.bad_request().json(error_json("server_error"));
733 });
734 let server = start_server(mocks).await;
735 let dir = tempfile::tempdir().unwrap();
736 let strategy =
738 auto_refresh_with_token(&dir, &server, make_token("still-usable", 30, true));
739
740 let token = strategy.get_token().await.unwrap();
742 assert_eq!(
743 token.as_str(),
744 "still-usable",
745 "first call should return still-usable token"
746 );
747
748 server.mocks().clear();
750 server.mocks().mock(|when, then| {
751 when.post().path("/oauth/token");
752 then.json(refresh_response_json("refreshed-token"));
753 });
754
755 let token = strategy.get_token().await.unwrap();
757 assert!(
758 token.as_str() == "still-usable" || token.as_str() == "refreshed-token",
759 "expected old or refreshed token, got: {}",
760 token.as_str()
761 );
762
763 let state = strategy.state.lock().await;
765 assert_eq!(
766 state.token.as_ref().unwrap().access_token().as_str(),
767 "refreshed-token",
768 "cache should hold the refreshed token after retry"
769 );
770 }
771 }
772 }
773
774 mod given_concurrent_callers {
775 use super::*;
776
777 #[tokio::test]
778 async fn returns_usable_token_while_refreshing() {
779 let mut mocks = MockSet::new();
780 mocks.mock(|when, then| {
781 when.post().path("/oauth/token");
782 then.json(refresh_response_json("refreshed-token"));
783 });
784 let server = start_server(mocks).await;
785 let dir = tempfile::tempdir().unwrap();
786 let strategy = Arc::new(auto_refresh_with_token(
787 &dir,
788 &server,
789 make_token("still-usable", 30, true),
790 ));
791
792 let s1 = Arc::clone(&strategy);
793 let handle_a = tokio::spawn(async move { s1.get_token().await.unwrap() });
794
795 let s2 = Arc::clone(&strategy);
796 let handle_b = tokio::spawn(async move { s2.get_token().await.unwrap() });
797
798 let (result_a, result_b) = tokio::join!(handle_a, handle_b);
799 let token_a = result_a.unwrap();
800 let token_b = result_b.unwrap();
801
802 assert!(
803 token_a.as_str() == "still-usable" || token_a.as_str() == "refreshed-token",
804 "unexpected token_a: {}",
805 token_a.as_str()
806 );
807 assert!(
808 token_b.as_str() == "still-usable" || token_b.as_str() == "refreshed-token",
809 "unexpected token_b: {}",
810 token_b.as_str()
811 );
812 }
813
814 #[tokio::test]
815 async fn blocks_until_refresh_completes() {
816 let mut mocks = MockSet::new();
817 mocks.mock(|when, then| {
818 when.post().path("/oauth/token");
819 then.json(refresh_response_json("refreshed-token"));
820 });
821 let server = start_server(mocks).await;
822 let dir = tempfile::tempdir().unwrap();
823 let strategy = Arc::new(auto_refresh_with_token(
824 &dir,
825 &server,
826 make_token("expired-token", 0, true),
827 ));
828
829 let s1 = Arc::clone(&strategy);
830 let handle_a = tokio::spawn(async move { s1.get_token().await.unwrap() });
831
832 let s2 = Arc::clone(&strategy);
833 let handle_b = tokio::spawn(async move { s2.get_token().await.unwrap() });
834
835 let (result_a, result_b) = tokio::join!(handle_a, handle_b);
836 let token_a = result_a.unwrap();
837 let token_b = result_b.unwrap();
838
839 assert_eq!(
840 token_a.as_str(),
841 "refreshed-token",
842 "caller a should receive refreshed token"
843 );
844 assert_eq!(
845 token_b.as_str(),
846 "refreshed-token",
847 "caller b should receive refreshed token"
848 );
849 }
850 }
851}
852
853#[cfg(test)]
854#[allow(clippy::unwrap_used)]
855mod stress_tests {
856 use super::*;
857 use crate::oauth_refresher::OAuthRefresher;
858 use crate::SecretToken;
859 use stack_profile::ProfileStore;
860 use std::sync::atomic::{AtomicUsize, Ordering};
861 use std::sync::Arc;
862 use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
863
864 #[derive(Clone)]
866 struct CountingState {
867 total: Arc<AtomicUsize>,
868 current: Arc<AtomicUsize>,
869 peak: Arc<AtomicUsize>,
870 }
871
872 impl CountingState {
873 fn new() -> Self {
874 Self {
875 total: Arc::new(AtomicUsize::new(0)),
876 current: Arc::new(AtomicUsize::new(0)),
877 peak: Arc::new(AtomicUsize::new(0)),
878 }
879 }
880
881 fn enter(&self) {
882 self.total.fetch_add(1, Ordering::SeqCst);
883 let prev = self.current.fetch_add(1, Ordering::SeqCst);
884 self.peak.fetch_max(prev + 1, Ordering::SeqCst);
885 }
886
887 fn exit(&self) {
888 self.current.fetch_sub(1, Ordering::SeqCst);
889 }
890
891 fn peak(&self) -> usize {
892 self.peak.load(Ordering::SeqCst)
893 }
894
895 fn total(&self) -> usize {
896 self.total.load(Ordering::SeqCst)
897 }
898 }
899
    /// State handed to the axum handlers of the test server.
    #[derive(Clone)]
    struct DelayedRefreshState {
        /// Counters updated on entry to and exit from each handler call.
        counting: CountingState,
        /// How long each handler sleeps before responding.
        delay: Duration,
    }
905
906 async fn delayed_refresh_handler(
907 axum::extract::State(state): axum::extract::State<DelayedRefreshState>,
908 ) -> axum::Json<serde_json::Value> {
909 state.counting.enter();
910 tokio::time::sleep(state.delay).await;
911 state.counting.exit();
912 axum::Json(serde_json::json!({
913 "access_token": "refreshed-token",
914 "token_type": "Bearer",
915 "expires_in": 3600,
916 "refresh_token": "new-refresh-token"
917 }))
918 }
919
920 async fn delayed_error_handler(
921 axum::extract::State(state): axum::extract::State<DelayedRefreshState>,
922 ) -> (axum::http::StatusCode, axum::Json<serde_json::Value>) {
923 state.counting.enter();
924 tokio::time::sleep(state.delay).await;
925 state.counting.exit();
926 (
927 axum::http::StatusCode::BAD_REQUEST,
928 axum::Json(serde_json::json!({
929 "error": "invalid_grant",
930 "error_description": "invalid_grant occurred"
931 })),
932 )
933 }
934
935 async fn start_axum_server<H, T>(
936 handler: H,
937 state: DelayedRefreshState,
938 ) -> (url::Url, CountingState)
939 where
940 H: axum::handler::Handler<T, DelayedRefreshState> + Clone + Send + 'static,
941 T: 'static,
942 {
943 let counting = state.counting.clone();
944 let app = axum::Router::new()
945 .route("/oauth/token", axum::routing::post(handler))
946 .with_state(state);
947 let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
948 let addr = listener.local_addr().unwrap();
949 tokio::spawn(async move {
950 axum::serve(listener, app).await.unwrap();
951 });
952 let base_url = url::Url::parse(&format!("http://{addr}")).unwrap();
953 (base_url, counting)
954 }
955
956 fn make_token(access: &str, expires_in: u64, refresh: bool) -> Token {
957 let now = SystemTime::now()
958 .duration_since(UNIX_EPOCH)
959 .unwrap()
960 .as_secs();
961
962 Token {
963 access_token: SecretToken::new(access),
964 token_type: "Bearer".to_string(),
965 expires_at: now + expires_in,
966 refresh_token: if refresh {
967 Some(SecretToken::new("test-refresh-token"))
968 } else {
969 None
970 },
971 region: None,
972 client_id: None,
973 device_instance_id: None,
974 }
975 }
976
977 fn auto_refresh_with_token(
978 dir: &tempfile::TempDir,
979 base_url: &url::Url,
980 token: Token,
981 ) -> AutoRefresh<OAuthRefresher> {
982 let store = ProfileStore::new(dir.path());
983 store.init_workspace("ZVATKW3VHMFG27DY").unwrap();
984 let ws_store = store.current_workspace_store().unwrap();
985 ws_store.save_profile(&token).unwrap();
986 let refresher = OAuthRefresher::new(
987 Some(ws_store),
988 base_url.clone(),
989 "cli",
990 "ap-southeast-2.aws",
991 None,
992 );
993 AutoRefresh::with_token(refresher, token)
994 }
995
    /// Number of `get_token` tasks each stress test spawns at once.
    const CONCURRENCY: usize = 50;
997
998 mod given_fresh_token {
999 use super::*;
1000
1001 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1002 async fn all_callers_return_immediately() {
1003 let counting = CountingState::new();
1004 let state = DelayedRefreshState {
1005 counting: counting.clone(),
1006 delay: Duration::from_millis(500),
1007 };
1008 let (base_url, stats) = start_axum_server(delayed_refresh_handler, state).await;
1009 let dir = tempfile::tempdir().unwrap();
1010 let strategy = Arc::new(auto_refresh_with_token(
1011 &dir,
1012 &base_url,
1013 make_token("fresh-token", 3600, true),
1014 ));
1015
1016 let start = Instant::now();
1017 let mut handles = Vec::with_capacity(CONCURRENCY);
1018 for _ in 0..CONCURRENCY {
1019 let s = Arc::clone(&strategy);
1020 handles.push(tokio::spawn(async move { s.get_token().await.unwrap() }));
1021 }
1022
1023 let results: Vec<_> = {
1024 let mut results = Vec::with_capacity(handles.len());
1025 for handle in handles {
1026 results.push(handle.await.unwrap());
1027 }
1028 results
1029 };
1030 let elapsed = start.elapsed();
1031
1032 for token in &results {
1033 assert_eq!(
1034 token.as_str(),
1035 "fresh-token",
1036 "all callers should receive the fresh token"
1037 );
1038 }
1039
1040 assert!(
1041 elapsed < Duration::from_millis(200),
1042 "expected < 200ms for fresh tokens, got {:?}",
1043 elapsed
1044 );
1045 assert_eq!(stats.total(), 0, "no refresh requests should be made");
1046 }
1047 }
1048
1049 mod given_expiring_but_usable_token {
1050 use super::*;
1051
1052 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1053 async fn non_blocking_reads_during_refresh() {
1054 let counting = CountingState::new();
1055 let state = DelayedRefreshState {
1056 counting: counting.clone(),
1057 delay: Duration::from_millis(500),
1058 };
1059 let (base_url, stats) = start_axum_server(delayed_refresh_handler, state).await;
1060 let dir = tempfile::tempdir().unwrap();
1061 let strategy = Arc::new(auto_refresh_with_token(
1062 &dir,
1063 &base_url,
1064 make_token("still-usable", 30, true),
1065 ));
1066
1067 let start = Instant::now();
1068 let mut handles = Vec::with_capacity(CONCURRENCY);
1069 for _ in 0..CONCURRENCY {
1070 let s = Arc::clone(&strategy);
1071 handles.push(tokio::spawn(async move {
1072 let call_start = Instant::now();
1073 let token = s.get_token().await.unwrap();
1074 (token, call_start.elapsed())
1075 }));
1076 }
1077
1078 let results: Vec<_> = {
1079 let mut results = Vec::with_capacity(handles.len());
1080 for handle in handles {
1081 results.push(handle.await.unwrap());
1082 }
1083 results
1084 };
1085 let elapsed = start.elapsed();
1086
1087 for (token, _) in &results {
1088 assert!(
1089 token.as_str() == "still-usable" || token.as_str() == "refreshed-token",
1090 "unexpected token: {}",
1091 token.as_str()
1092 );
1093 }
1094
1095 let fast_callers = results
1096 .iter()
1097 .filter(|(_, dur)| *dur < Duration::from_millis(100))
1098 .count();
1099 assert!(
1100 fast_callers >= CONCURRENCY - 1,
1101 "expected at least {} fast callers, got {} (total elapsed: {:?})",
1102 CONCURRENCY - 1,
1103 fast_callers,
1104 elapsed
1105 );
1106
1107 assert_eq!(stats.peak(), 1, "peak concurrency to refresh endpoint");
1108 assert_eq!(stats.total(), 1, "total refresh requests");
1109 }
1110
1111 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1116 async fn waiters_receive_token_when_expiry_crosses() {
1117 let refresh_delay = Duration::from_millis(1500);
1122 let counting = CountingState::new();
1123 let state = DelayedRefreshState {
1124 counting: counting.clone(),
1125 delay: refresh_delay,
1126 };
1127 let (base_url, stats) = start_axum_server(delayed_refresh_handler, state).await;
1128 let dir = tempfile::tempdir().unwrap();
1129 let strategy = Arc::new(auto_refresh_with_token(
1130 &dir,
1131 &base_url,
1132 make_token("expiring-soon", 1, true),
1133 ));
1134
1135 let first = strategy.get_token().await.unwrap();
1137 assert_eq!(
1138 first.as_str(),
1139 "expiring-soon",
1140 "first caller should receive the expiring token"
1141 );
1142
1143 tokio::time::sleep(Duration::from_millis(1100)).await;
1145
1146 let mut handles = Vec::with_capacity(CONCURRENCY);
1149 for _ in 0..CONCURRENCY {
1150 let s = Arc::clone(&strategy);
1151 handles.push(tokio::spawn(async move { s.get_token().await }));
1152 }
1153
1154 let results: Vec<_> = {
1155 let mut results = Vec::with_capacity(handles.len());
1156 for handle in handles {
1157 results.push(handle.await.unwrap());
1158 }
1159 results
1160 };
1161
1162 for (i, result) in results.iter().enumerate() {
1164 assert!(
1165 result.is_ok(),
1166 "caller {i} got Err({:?}), expected Ok",
1167 result.as_ref().unwrap_err()
1168 );
1169 assert_eq!(
1170 result.as_ref().unwrap().as_str(),
1171 "refreshed-token",
1172 "caller {i} should receive the refreshed token"
1173 );
1174 }
1175
1176 assert_eq!(stats.total(), 1, "only one refresh request should be made");
1177 }
1178 }
1179
1180 mod given_fully_expired_token {
1181 use super::*;
1182
1183 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1184 async fn all_callers_block_until_refresh() {
1185 let refresh_delay = Duration::from_millis(200);
1186 let counting = CountingState::new();
1187 let state = DelayedRefreshState {
1188 counting: counting.clone(),
1189 delay: refresh_delay,
1190 };
1191 let (base_url, stats) = start_axum_server(delayed_refresh_handler, state).await;
1192 let dir = tempfile::tempdir().unwrap();
1193 let strategy = Arc::new(auto_refresh_with_token(
1194 &dir,
1195 &base_url,
1196 make_token("expired-token", 0, true),
1197 ));
1198
1199 let start = Instant::now();
1200 let mut handles = Vec::with_capacity(CONCURRENCY);
1201 for _ in 0..CONCURRENCY {
1202 let s = Arc::clone(&strategy);
1203 handles.push(tokio::spawn(async move { s.get_token().await.unwrap() }));
1204 }
1205
1206 let results: Vec<_> = {
1207 let mut results = Vec::with_capacity(handles.len());
1208 for handle in handles {
1209 results.push(handle.await.unwrap());
1210 }
1211 results
1212 };
1213 let elapsed = start.elapsed();
1214
1215 for token in &results {
1216 assert_eq!(
1217 token.as_str(),
1218 "refreshed-token",
1219 "all callers should receive refreshed token"
1220 );
1221 }
1222
1223 assert!(
1224 elapsed < refresh_delay + Duration::from_millis(200),
1225 "expected < {:?} for blocked callers, got {:?}",
1226 refresh_delay + Duration::from_millis(200),
1227 elapsed
1228 );
1229
1230 assert_eq!(stats.peak(), 1, "peak concurrency to refresh endpoint");
1231 assert_eq!(stats.total(), 1, "total refresh requests");
1232 }
1233
1234 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1235 async fn all_callers_receive_expired_on_failure() {
1236 let counting = CountingState::new();
1237 let state = DelayedRefreshState {
1238 counting: counting.clone(),
1239 delay: Duration::from_millis(10),
1240 };
1241 let (base_url, stats) = start_axum_server(delayed_error_handler, state).await;
1242 let dir = tempfile::tempdir().unwrap();
1243 let strategy = Arc::new(auto_refresh_with_token(
1244 &dir,
1245 &base_url,
1246 make_token("expired-token", 0, true),
1247 ));
1248
1249 let mut handles = Vec::with_capacity(CONCURRENCY);
1250 for _ in 0..CONCURRENCY {
1251 let s = Arc::clone(&strategy);
1252 handles.push(tokio::spawn(async move { s.get_token().await }));
1253 }
1254
1255 let results: Vec<_> = {
1256 let mut results = Vec::with_capacity(handles.len());
1257 for handle in handles {
1258 results.push(handle.await.unwrap());
1259 }
1260 results
1261 };
1262
1263 for result in &results {
1264 assert!(result.is_err(), "expected Expired error, got Ok");
1265 let err = result.as_ref().unwrap_err();
1266 assert!(
1267 matches!(err, AutoRefreshError::Expired),
1268 "expected Expired, got: {err:?}"
1269 );
1270 }
1271
1272 let state = strategy.state.lock().await;
1273 assert!(
1274 state.token.as_ref().unwrap().refresh_token().is_some(),
1275 "refresh token should be restored after failed refresh"
1276 );
1277 drop(state);
1278
1279 assert_eq!(stats.peak(), 1, "peak concurrency to refresh endpoint");
1280 assert!(
1281 stats.total() >= 1,
1282 "at least one refresh attempt should be made"
1283 );
1284 }
1285
1286 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1287 async fn retry_succeeds_after_failure() {
1288 let counting1 = CountingState::new();
1290 let state1 = DelayedRefreshState {
1291 counting: counting1.clone(),
1292 delay: Duration::from_millis(50),
1293 };
1294 let (base_url, _) = start_axum_server(delayed_error_handler, state1).await;
1295 let dir = tempfile::tempdir().unwrap();
1296 let strategy = Arc::new(auto_refresh_with_token(
1297 &dir,
1298 &base_url,
1299 make_token("expired-token", 0, true),
1300 ));
1301
1302 let mut handles = Vec::with_capacity(CONCURRENCY);
1303 for _ in 0..CONCURRENCY {
1304 let s = Arc::clone(&strategy);
1305 handles.push(tokio::spawn(async move { s.get_token().await }));
1306 }
1307
1308 let results: Vec<_> = {
1309 let mut results = Vec::with_capacity(handles.len());
1310 for handle in handles {
1311 results.push(handle.await.unwrap());
1312 }
1313 results
1314 };
1315
1316 for result in &results {
1317 assert!(
1318 result.is_err(),
1319 "first wave: expected Expired, got Ok({})",
1320 result.as_ref().unwrap().as_str()
1321 );
1322 }
1323
1324 let counting2 = CountingState::new();
1326 let state2 = DelayedRefreshState {
1327 counting: counting2.clone(),
1328 delay: Duration::from_millis(50),
1329 };
1330 let (base_url2, stats2) = start_axum_server(delayed_refresh_handler, state2).await;
1331
1332 let strategy2 = Arc::new(auto_refresh_with_token(
1333 &dir,
1334 &base_url2,
1335 make_token("expired-token", 0, true),
1336 ));
1337
1338 let mut handles = Vec::with_capacity(CONCURRENCY);
1339 for _ in 0..CONCURRENCY {
1340 let s = Arc::clone(&strategy2);
1341 handles.push(tokio::spawn(async move { s.get_token().await.unwrap() }));
1342 }
1343
1344 let results: Vec<_> = {
1345 let mut results = Vec::with_capacity(handles.len());
1346 for handle in handles {
1347 results.push(handle.await.unwrap());
1348 }
1349 results
1350 };
1351
1352 for token in &results {
1353 assert_eq!(
1354 token.as_str(),
1355 "refreshed-token",
1356 "retry callers should receive refreshed token"
1357 );
1358 }
1359
1360 assert_eq!(stats2.total(), 1, "only one retry refresh should be made");
1361 }
1362 }
1363
1364 mod given_cancelled_refresh {
1365 use super::*;
1366
1367 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1371 async fn blocked_callers_recover_after_cancellation() {
1372 let counting = CountingState::new();
1373 let state = DelayedRefreshState {
1374 counting: counting.clone(),
1375 delay: Duration::from_secs(10), };
1377 let (base_url, _) = start_axum_server(delayed_refresh_handler, state).await;
1378 let dir = tempfile::tempdir().unwrap();
1379 let strategy = Arc::new(auto_refresh_with_token(
1380 &dir,
1381 &base_url,
1382 make_token("expired-token", 0, true),
1383 ));
1384
1385 let s = Arc::clone(&strategy);
1387 let handle = tokio::spawn(async move { s.get_token().await });
1388 tokio::time::sleep(Duration::from_millis(100)).await;
1389
1390 handle.abort();
1392 let _ = handle.await;
1393
1394 let s = Arc::clone(&strategy);
1398 let result = tokio::time::timeout(Duration::from_secs(2), s.get_token()).await;
1399
1400 assert!(
1401 result.is_ok(),
1402 "get_token() should not hang after cancelled blocking refresh"
1403 );
1404 }
1405
1406 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1410 async fn non_blocking_callers_recover_after_cancellation() {
1411 let counting = CountingState::new();
1412 let state = DelayedRefreshState {
1413 counting: counting.clone(),
1414 delay: Duration::from_secs(10), };
1416 let (base_url, _) = start_axum_server(delayed_refresh_handler, state).await;
1417 let dir = tempfile::tempdir().unwrap();
1418 let strategy = Arc::new(auto_refresh_with_token(
1420 &dir,
1421 &base_url,
1422 make_token("still-usable", 30, true),
1423 ));
1424
1425 let s = Arc::clone(&strategy);
1428 let handle = tokio::spawn(async move { s.get_token().await });
1429 tokio::time::sleep(Duration::from_millis(100)).await;
1430
1431 handle.abort();
1433 let _ = handle.await;
1434
1435 let s = Arc::clone(&strategy);
1438 let result = tokio::time::timeout(Duration::from_secs(2), s.get_token()).await;
1439
1440 assert!(
1441 result.is_ok(),
1442 "get_token() should not hang after cancelled non-blocking refresh"
1443 );
1444 let result = result.unwrap();
1445 assert!(
1446 result.is_ok(),
1447 "expected Ok with still-usable token, got: {:?}",
1448 result.unwrap_err()
1449 );
1450 }
1451 }
1452}