1use std::sync::atomic::{AtomicBool, Ordering};
2
3use tokio::sync::{Mutex, MutexGuard, Notify};
4
5use crate::refresher::Refresher;
6use crate::token_store::{NoStore, TokenStore};
7use crate::{ServiceToken, Token};
8
/// Errors produced while obtaining a token through [`AutoRefresh`].
#[derive(Debug, thiserror::Error)]
pub(crate) enum AutoRefreshError {
    /// Nothing is cached or stored, and the refresher offers no credential to get a token.
    #[error("No token found")]
    NotFound,
    /// The cached token is not usable and could not be refreshed.
    #[error("Token has expired")]
    Expired,
    /// An underlying auth failure (surfaced by the initial authentication path).
    #[error("Auth error: {0}")]
    Auth(#[from] crate::AuthError),
}
25
26impl From<AutoRefreshError> for crate::AuthError {
27 fn from(err: AutoRefreshError) -> Self {
28 match err {
29 AutoRefreshError::NotFound => crate::AuthError::NotAuthenticated,
30 AutoRefreshError::Expired => crate::AuthError::TokenExpired,
31 AutoRefreshError::Auth(e) => e,
32 }
33 }
34}
35
/// Caches a token and renews it on demand through a [`Refresher`], optionally
/// backed by a [`TokenStore`] (defaults to [`NoStore`]).
///
/// `refresh_in_progress` and `refresh_notify` coordinate concurrent callers so a
/// caller that finds a refresh already running does not start a second one.
pub(crate) struct AutoRefresh<R, S = NoStore> {
    // Supplies refresh credentials and performs the token exchange.
    refresher: R,
    // Cached token. This is a tokio mutex: it is held across `.await`s, e.g. for the
    // whole exchange during initial auth and blocking refresh.
    state: Mutex<State>,
    // Loaded from while nothing is cached; saved to after every successful refresh.
    store: S,
    // `true` while an initial auth or refresh is running.
    refresh_in_progress: AtomicBool,
    // Signalled when a refresh finishes (or is cancelled) so parked callers re-check.
    refresh_notify: Notify,
}
54
/// Mutable cache contents, guarded by `AutoRefresh::state`.
struct State {
    // `None` until a token is loaded from the store or obtained via initial auth.
    token: Option<Token>,
}
58
59struct CancelGuard<'a> {
65 in_progress: &'a AtomicBool,
66 notify: &'a Notify,
67 defused: bool,
68}
69
70impl Drop for CancelGuard<'_> {
71 fn drop(&mut self) {
72 if !self.defused {
73 self.in_progress.store(false, Ordering::Release);
74 self.notify.notify_waiters();
75 }
76 }
77}
78
79impl CancelGuard<'_> {
80 fn defuse(&mut self) {
81 self.defused = true;
82 }
83}
84
85impl State {
86 fn service_token(&self) -> Result<ServiceToken, AutoRefreshError> {
87 let token = self.token.as_ref().ok_or(AutoRefreshError::NotFound)?;
88 Ok(ServiceToken::new(token.access_token().clone()))
89 }
90
91 fn require_usable_token(&self) -> Result<ServiceToken, AutoRefreshError> {
92 let token = self.token.as_ref().ok_or(AutoRefreshError::NotFound)?;
93 if token.is_usable() {
94 Ok(ServiceToken::new(token.access_token().clone()))
95 } else {
96 Err(AutoRefreshError::Expired)
97 }
98 }
99}
100
101impl<R> AutoRefresh<R, NoStore> {
102 pub(crate) fn with_token(refresher: R, token: Token) -> Self {
107 Self {
108 refresher,
109 state: Mutex::new(State { token: Some(token) }),
110 store: NoStore,
111 refresh_in_progress: AtomicBool::new(false),
112 refresh_notify: Notify::new(),
113 }
114 }
115}
116
117impl<R, S: TokenStore> AutoRefresh<R, S> {
118 pub(crate) fn with_store(refresher: R, store: S) -> Self {
126 Self {
127 refresher,
128 state: Mutex::new(State { token: None }),
129 store,
130 refresh_in_progress: AtomicBool::new(false),
131 refresh_notify: Notify::new(),
132 }
133 }
134}
135
136impl<R: Refresher, S: TokenStore> AutoRefresh<R, S> {
137 pub(crate) async fn get_token(&self) -> Result<ServiceToken, AutoRefreshError> {
139 let mut state = self.state.lock().await;
140
141 if state.token.is_none() {
142 drop(state);
147 let loaded = self.store.load().await;
148 state = self.state.lock().await;
149 if state.token.is_none() {
150 state.token = loaded;
151 }
152 }
153
154 if state.token.is_none() {
155 return self.initial_auth(&mut state).await;
156 }
157
158 if !state.token.as_ref().is_some_and(|t| t.is_expired()) {
159 return state.service_token();
160 }
161
162 if self.refresh_in_progress.load(Ordering::Acquire) {
163 return self.wait_for_in_flight_refresh(state).await;
164 }
165
166 let Some(credential) = self.refresher.try_credential(state.token.as_mut()) else {
167 return state.require_usable_token();
168 };
169
170 self.refresh_in_progress.store(true, Ordering::Release);
171
172 if state.token.as_ref().is_some_and(|t| t.is_usable()) {
173 self.refresh_non_blocking(state, credential).await
174 } else {
175 self.refresh_blocking(&mut state, credential).await
176 }
177 }
178
179 async fn initial_auth(&self, state: &mut State) -> Result<ServiceToken, AutoRefreshError> {
183 let Some(credential) = self.refresher.try_credential(None) else {
184 return Err(AutoRefreshError::NotFound);
185 };
186 self.refresh_in_progress.store(true, Ordering::Release);
187 let mut guard = CancelGuard {
188 in_progress: &self.refresh_in_progress,
189 notify: &self.refresh_notify,
190 defused: false,
191 };
192 match self.refresher.refresh(&credential).await {
193 Ok(new_token) => {
194 guard.defuse();
195 self.save_refreshed_token(&new_token).await;
196 Ok(self.install_refreshed_token(state, new_token))
197 }
198 Err(err) => {
199 guard.defuse();
200 self.refresh_in_progress.store(false, Ordering::Release);
201 Err(AutoRefreshError::Auth(err))
202 }
203 }
204 }
205
206 async fn save_refreshed_token(&self, new_token: &Token) {
212 self.refresher.save(new_token);
213 self.store.save(new_token).await;
214 }
215
216 fn install_refreshed_token(&self, state: &mut State, new_token: Token) -> ServiceToken {
221 let service_token = ServiceToken::new(new_token.access_token().clone());
222 state.token = Some(new_token);
223 self.refresh_in_progress.store(false, Ordering::Release);
224 service_token
225 }
226
227 async fn wait_for_in_flight_refresh(
233 &self,
234 state: MutexGuard<'_, State>,
235 ) -> Result<ServiceToken, AutoRefreshError> {
236 if let Ok(token) = state.service_token() {
237 if state.token.as_ref().is_some_and(|t| t.is_usable()) {
238 return Ok(token);
239 }
240 }
241 let notified = self.refresh_notify.notified();
244 drop(state);
245 notified.await;
246 let state = self.state.lock().await;
248 state.require_usable_token()
249 }
250
251 async fn refresh_non_blocking(
261 &self,
262 state: MutexGuard<'_, State>,
263 credential: R::Credential,
264 ) -> Result<ServiceToken, AutoRefreshError> {
265 let current_service_token = state.service_token()?;
266 drop(state);
267
268 let mut guard = CancelGuard {
269 in_progress: &self.refresh_in_progress,
270 notify: &self.refresh_notify,
271 defused: false,
272 };
273
274 match self.refresher.refresh(&credential).await {
275 Ok(new_token) => {
276 guard.defuse();
277 self.save_refreshed_token(&new_token).await;
278 let mut state = self.state.lock().await;
279 let _ = self.install_refreshed_token(&mut state, new_token);
280 }
281 Err(err) => {
282 guard.defuse();
283 tracing::warn!(%err, "token refresh failed (token still usable)");
284 let mut state = self.state.lock().await;
285 if let Some(token) = state.token.as_mut() {
286 self.refresher.restore(token, credential);
287 }
288 self.refresh_in_progress.store(false, Ordering::Release);
289 }
290 }
291
292 self.refresh_notify.notify_waiters();
293 Ok(current_service_token)
294 }
295
296 async fn refresh_blocking(
305 &self,
306 state: &mut State,
307 credential: R::Credential,
308 ) -> Result<ServiceToken, AutoRefreshError> {
309 let mut guard = CancelGuard {
310 in_progress: &self.refresh_in_progress,
311 notify: &self.refresh_notify,
312 defused: false,
313 };
314 match self.refresher.refresh(&credential).await {
315 Ok(new_token) => {
316 guard.defuse();
317 self.save_refreshed_token(&new_token).await;
318 Ok(self.install_refreshed_token(state, new_token))
319 }
320 Err(err) => {
321 guard.defuse();
322 tracing::warn!(%err, "token refresh failed");
323 if let Some(token) = state.token.as_mut() {
324 self.refresher.restore(token, credential);
325 }
326 self.refresh_in_progress.store(false, Ordering::Release);
327 Err(AutoRefreshError::Expired)
328 }
329 }
330 }
331}
332
#[cfg(test)]
// Tests unwrap freely: a panic is exactly the failure signal wanted here.
#[allow(clippy::unwrap_used)]
mod tests {
    //! Behavioural tests for `AutoRefresh` against a mocked OAuth token endpoint.

    use super::*;
    use crate::oauth_refresher::OAuthRefresher;
    use crate::SecretToken;
    use mocktail::prelude::*;
    use stack_profile::ProfileStore;
    use std::sync::Arc;
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Builds a token expiring `expires_in` seconds from now, optionally carrying a
    /// refresh token.
    fn make_token(access: &str, expires_in: u64, refresh: bool) -> Token {
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs();

        Token {
            access_token: SecretToken::new(access),
            token_type: "Bearer".to_string(),
            expires_at: now + expires_in,
            refresh_token: if refresh {
                Some(SecretToken::new("test-refresh-token"))
            } else {
                None
            },
            region: None,
            client_id: None,
            device_instance_id: None,
        }
    }

    /// Successful `/oauth/token` response body returning `access`.
    fn refresh_response_json(access: &str) -> serde_json::Value {
        serde_json::json!({
            "access_token": access,
            "token_type": "Bearer",
            "expires_in": 3600,
            "refresh_token": "new-refresh-token"
        })
    }

    /// OAuth-style error response body.
    fn error_json(error: &str) -> serde_json::Value {
        serde_json::json!({
            "error": error,
            "error_description": format!("{error} occurred")
        })
    }

    async fn start_server(mocks: MockSet) -> MockServer {
        let server = MockServer::new_http("auto-refresh-test").with_mocks(mocks);
        server.start().await.unwrap();
        server
    }

    /// Persists `token` into a fresh workspace under `dir` and builds an
    /// `AutoRefresh` seeded with it that refreshes against `server`.
    fn auto_refresh_with_token(
        dir: &tempfile::TempDir,
        server: &MockServer,
        token: Token,
    ) -> AutoRefresh<OAuthRefresher> {
        let store = ProfileStore::new(dir.path());
        store.init_workspace("ZVATKW3VHMFG27DY").unwrap();
        let ws_store = store.current_workspace_store().unwrap();
        ws_store.save_profile(&token).unwrap();
        let refresher = OAuthRefresher::new(
            Some(ws_store),
            server.url(""),
            "cli",
            "ap-southeast-2.aws",
            None,
        );
        AutoRefresh::with_token(refresher, token)
    }

    mod given_no_cached_token {
        use super::*;

        #[tokio::test]
        async fn returns_not_found_for_oauth() {
            let server = start_server(MockSet::new()).await;
            // Use a path inside a private temp dir so the test cannot be affected by
            // whatever happens to exist at a shared, machine-global location.
            let dir = tempfile::tempdir().unwrap();
            let store = ProfileStore::new(dir.path().join("nonexistent"));
            let refresher = OAuthRefresher::new(
                Some(store),
                server.url(""),
                "cli",
                "ap-southeast-2.aws",
                None,
            );
            let strategy = AutoRefresh::with_store(refresher, NoStore);

            let err = strategy.get_token().await.unwrap_err();

            assert!(
                matches!(err, AutoRefreshError::NotFound),
                "expected NotFound, got: {err:?}"
            );
        }
    }

    mod given_fresh_token {
        use super::*;

        #[tokio::test]
        async fn returns_cached_token() {
            let dir = tempfile::tempdir().unwrap();
            let server = start_server(MockSet::new()).await;
            let strategy =
                auto_refresh_with_token(&dir, &server, make_token("my-access-token", 3600, false));

            let token = strategy.get_token().await.unwrap();

            assert_eq!(
                token.as_str(),
                "my-access-token",
                "should return the cached access token"
            );
        }

        #[tokio::test]
        async fn caches_across_calls() {
            let dir = tempfile::tempdir().unwrap();
            let server = start_server(MockSet::new()).await;
            let strategy =
                auto_refresh_with_token(&dir, &server, make_token("my-access-token", 3600, false));

            let token1 = strategy.get_token().await.unwrap();
            assert_eq!(
                token1.as_str(),
                "my-access-token",
                "first call should return the cached token"
            );

            // Removing the on-disk profile proves the second call is served from memory.
            std::fs::remove_file(
                dir.path()
                    .join("workspaces")
                    .join("ZVATKW3VHMFG27DY")
                    .join("auth.json"),
            )
            .unwrap();

            let token2 = strategy.get_token().await.unwrap();
            assert_eq!(
                token2.as_str(),
                "my-access-token",
                "second call should return the cached token even after file deletion"
            );
        }

        #[tokio::test]
        async fn does_not_trigger_refresh() {
            let mut mocks = MockSet::new();
            // Any call to the token endpoint would fail the test via an error response.
            mocks.mock(|when, then| {
                when.post().path("/oauth/token");
                then.internal_server_error()
                    .json(error_json("should_not_be_called"));
            });
            let server = start_server(mocks).await;
            let dir = tempfile::tempdir().unwrap();
            let strategy =
                auto_refresh_with_token(&dir, &server, make_token("fresh-token", 3600, true));

            let token = strategy.get_token().await.unwrap();

            assert_eq!(
                token.as_str(),
                "fresh-token",
                "should return fresh token without triggering refresh"
            );
        }
    }

    mod given_fully_expired_token {
        use super::*;

        mod without_refresh_token {
            use super::*;

            #[tokio::test]
            async fn returns_expired() {
                let dir = tempfile::tempdir().unwrap();
                let server = start_server(MockSet::new()).await;
                let strategy =
                    auto_refresh_with_token(&dir, &server, make_token("old-token", 0, false));

                let err = strategy.get_token().await.unwrap_err();

                assert!(
                    matches!(err, AutoRefreshError::Expired),
                    "expected Expired, got: {err:?}"
                );
            }
        }

        mod with_refresh_token {
            use super::*;

            #[tokio::test]
            async fn refreshes_and_returns_new_token() {
                let mut mocks = MockSet::new();
                mocks.mock(|when, then| {
                    when.post().path("/oauth/token");
                    then.json(refresh_response_json("refreshed-token"));
                });
                let server = start_server(mocks).await;
                let dir = tempfile::tempdir().unwrap();
                let strategy =
                    auto_refresh_with_token(&dir, &server, make_token("old-token", 0, true));

                let token = strategy.get_token().await.unwrap();

                assert_eq!(
                    token.as_str(),
                    "refreshed-token",
                    "should return the refreshed token"
                );
            }

            #[tokio::test]
            async fn persists_refreshed_token_to_disk() {
                let mut mocks = MockSet::new();
                mocks.mock(|when, then| {
                    when.post().path("/oauth/token");
                    then.json(refresh_response_json("refreshed-token"));
                });
                let server = start_server(mocks).await;
                let dir = tempfile::tempdir().unwrap();
                let strategy =
                    auto_refresh_with_token(&dir, &server, make_token("old-token", 0, true));

                let _ = strategy.get_token().await.unwrap();

                let store = ProfileStore::new(dir.path());
                let ws_store = store.current_workspace_store().unwrap();
                let on_disk: Token = ws_store.load_profile().unwrap();
                assert_eq!(
                    on_disk.access_token().as_str(),
                    "refreshed-token",
                    "refreshed token should be persisted to disk"
                );
            }

            #[tokio::test]
            async fn returns_expired_on_refresh_failure() {
                let mut mocks = MockSet::new();
                mocks.mock(|when, then| {
                    when.post().path("/oauth/token");
                    then.bad_request().json(error_json("invalid_grant"));
                });
                let server = start_server(mocks).await;
                let dir = tempfile::tempdir().unwrap();
                let strategy =
                    auto_refresh_with_token(&dir, &server, make_token("old-token", 0, true));

                let err = strategy.get_token().await.unwrap_err();

                assert!(
                    matches!(err, AutoRefreshError::Expired),
                    "expected Expired after failed refresh, got: {err:?}"
                );
            }

            #[tokio::test]
            async fn restores_refresh_token_after_failure() {
                let mut mocks = MockSet::new();
                mocks.mock(|when, then| {
                    when.post().path("/oauth/token");
                    then.bad_request().json(error_json("invalid_grant"));
                });
                let server = start_server(mocks).await;
                let dir = tempfile::tempdir().unwrap();
                let strategy =
                    auto_refresh_with_token(&dir, &server, make_token("old-token", 0, true));

                let err = strategy.get_token().await.unwrap_err();
                assert!(
                    matches!(err, AutoRefreshError::Expired),
                    "expected Expired on first attempt, got: {err:?}"
                );

                let state = strategy.state.lock().await;
                assert!(
                    state.token.is_some(),
                    "token should still be cached after failed refresh"
                );
                assert!(
                    state.token.as_ref().unwrap().refresh_token().is_some(),
                    "refresh token should be restored for retry"
                );
                drop(state);

                server.mocks().clear();
                server.mocks().mock(|when, then| {
                    when.post().path("/oauth/token");
                    then.json(refresh_response_json("refreshed-token"));
                });

                let token = strategy.get_token().await.unwrap();
                assert_eq!(
                    token.as_str(),
                    "refreshed-token",
                    "retry should succeed with restored refresh token"
                );
            }

            #[tokio::test]
            async fn sequential_calls_only_refresh_once() {
                let mut mocks = MockSet::new();
                mocks.mock(|when, then| {
                    when.post().path("/oauth/token");
                    then.json(refresh_response_json("refreshed-once"));
                });
                let server = start_server(mocks).await;
                let dir = tempfile::tempdir().unwrap();
                let strategy =
                    auto_refresh_with_token(&dir, &server, make_token("old-token", 0, true));

                let token = strategy.get_token().await.unwrap();
                assert_eq!(
                    token.as_str(),
                    "refreshed-once",
                    "first call should trigger refresh"
                );

                // A second refresh would now return a different token and be detected.
                server.mocks().clear();
                server.mocks().mock(|when, then| {
                    when.post().path("/oauth/token");
                    then.json(refresh_response_json("refreshed-twice"));
                });

                for _ in 0..4 {
                    let token = strategy.get_token().await.unwrap();
                    assert_eq!(
                        token.as_str(),
                        "refreshed-once",
                        "should return cached refreshed token, not trigger another refresh"
                    );
                }
            }

            #[tokio::test]
            async fn prevents_second_refresh_after_success() {
                let mut mocks = MockSet::new();
                mocks.mock(|when, then| {
                    when.post().path("/oauth/token");
                    then.json(refresh_response_json("refreshed-token"));
                });
                let server = start_server(mocks).await;
                let dir = tempfile::tempdir().unwrap();
                let strategy =
                    auto_refresh_with_token(&dir, &server, make_token("old-token", 0, true));

                let token = strategy.get_token().await.unwrap();
                assert_eq!(
                    token.as_str(),
                    "refreshed-token",
                    "first call should refresh the token"
                );

                server.mocks().clear();
                server.mocks().mock(|when, then| {
                    when.post().path("/oauth/token");
                    then.bad_request().json(error_json("should_not_be_called"));
                });

                let token = strategy.get_token().await.unwrap();
                assert_eq!(
                    token.as_str(),
                    "refreshed-token",
                    "second call should return cached refreshed token"
                );
            }
        }
    }

    mod given_expiring_but_usable_token {
        use super::*;

        mod when_refresh_fails {
            use super::*;

            #[tokio::test]
            async fn returns_current_token() {
                let mut mocks = MockSet::new();
                mocks.mock(|when, then| {
                    when.post().path("/oauth/token");
                    then.bad_request().json(error_json("server_error"));
                });
                let server = start_server(mocks).await;
                let dir = tempfile::tempdir().unwrap();
                let strategy =
                    auto_refresh_with_token(&dir, &server, make_token("still-usable", 30, true));

                let token = strategy.get_token().await.unwrap();
                assert_eq!(
                    token.as_str(),
                    "still-usable",
                    "should return still-usable token despite failed refresh"
                );

                let state = strategy.state.lock().await;
                assert!(state.token.is_some(), "token should still be cached");
                assert_eq!(
                    state.token.as_ref().unwrap().access_token().as_str(),
                    "still-usable",
                    "access token should be unchanged after failed refresh"
                );
                assert!(
                    state.token.as_ref().unwrap().refresh_token().is_some(),
                    "refresh token should be restored after failed refresh"
                );
            }

            #[tokio::test]
            async fn restores_refresh_token_for_retry() {
                let mut mocks = MockSet::new();
                mocks.mock(|when, then| {
                    when.post().path("/oauth/token");
                    then.bad_request().json(error_json("server_error"));
                });
                let server = start_server(mocks).await;
                let dir = tempfile::tempdir().unwrap();
                let strategy =
                    auto_refresh_with_token(&dir, &server, make_token("still-usable", 30, true));

                let token = strategy.get_token().await.unwrap();
                assert_eq!(
                    token.as_str(),
                    "still-usable",
                    "first call should return still-usable token"
                );

                server.mocks().clear();
                server.mocks().mock(|when, then| {
                    when.post().path("/oauth/token");
                    then.json(refresh_response_json("refreshed-token"));
                });

                let token = strategy.get_token().await.unwrap();
                assert!(
                    token.as_str() == "still-usable" || token.as_str() == "refreshed-token",
                    "expected old or refreshed token, got: {}",
                    token.as_str()
                );

                let state = strategy.state.lock().await;
                assert_eq!(
                    state.token.as_ref().unwrap().access_token().as_str(),
                    "refreshed-token",
                    "cache should hold the refreshed token after retry"
                );
            }
        }
    }

    mod given_concurrent_callers {
        use super::*;

        #[tokio::test]
        async fn returns_usable_token_while_refreshing() {
            let mut mocks = MockSet::new();
            mocks.mock(|when, then| {
                when.post().path("/oauth/token");
                then.json(refresh_response_json("refreshed-token"));
            });
            let server = start_server(mocks).await;
            let dir = tempfile::tempdir().unwrap();
            let strategy = Arc::new(auto_refresh_with_token(
                &dir,
                &server,
                make_token("still-usable", 30, true),
            ));

            let s1 = Arc::clone(&strategy);
            let handle_a = tokio::spawn(async move { s1.get_token().await.unwrap() });

            let s2 = Arc::clone(&strategy);
            let handle_b = tokio::spawn(async move { s2.get_token().await.unwrap() });

            let (result_a, result_b) = tokio::join!(handle_a, handle_b);
            let token_a = result_a.unwrap();
            let token_b = result_b.unwrap();

            assert!(
                token_a.as_str() == "still-usable" || token_a.as_str() == "refreshed-token",
                "unexpected token_a: {}",
                token_a.as_str()
            );
            assert!(
                token_b.as_str() == "still-usable" || token_b.as_str() == "refreshed-token",
                "unexpected token_b: {}",
                token_b.as_str()
            );
        }

        #[tokio::test]
        async fn blocks_until_refresh_completes() {
            let mut mocks = MockSet::new();
            mocks.mock(|when, then| {
                when.post().path("/oauth/token");
                then.json(refresh_response_json("refreshed-token"));
            });
            let server = start_server(mocks).await;
            let dir = tempfile::tempdir().unwrap();
            let strategy = Arc::new(auto_refresh_with_token(
                &dir,
                &server,
                make_token("expired-token", 0, true),
            ));

            let s1 = Arc::clone(&strategy);
            let handle_a = tokio::spawn(async move { s1.get_token().await.unwrap() });

            let s2 = Arc::clone(&strategy);
            let handle_b = tokio::spawn(async move { s2.get_token().await.unwrap() });

            let (result_a, result_b) = tokio::join!(handle_a, handle_b);
            let token_a = result_a.unwrap();
            let token_b = result_b.unwrap();

            assert_eq!(
                token_a.as_str(),
                "refreshed-token",
                "caller a should receive refreshed token"
            );
            assert_eq!(
                token_b.as_str(),
                "refreshed-token",
                "caller b should receive refreshed token"
            );
        }
    }
}
888
889#[cfg(test)]
890#[allow(clippy::unwrap_used)]
891mod stress_tests {
892 use super::*;
893 use crate::oauth_refresher::OAuthRefresher;
894 use crate::SecretToken;
895 use stack_profile::ProfileStore;
896 use std::sync::atomic::{AtomicUsize, Ordering};
897 use std::sync::Arc;
898 use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
899
900 #[derive(Clone)]
902 struct CountingState {
903 total: Arc<AtomicUsize>,
904 current: Arc<AtomicUsize>,
905 peak: Arc<AtomicUsize>,
906 }
907
908 impl CountingState {
909 fn new() -> Self {
910 Self {
911 total: Arc::new(AtomicUsize::new(0)),
912 current: Arc::new(AtomicUsize::new(0)),
913 peak: Arc::new(AtomicUsize::new(0)),
914 }
915 }
916
917 fn enter(&self) {
918 self.total.fetch_add(1, Ordering::SeqCst);
919 let prev = self.current.fetch_add(1, Ordering::SeqCst);
920 self.peak.fetch_max(prev + 1, Ordering::SeqCst);
921 }
922
923 fn exit(&self) {
924 self.current.fetch_sub(1, Ordering::SeqCst);
925 }
926
927 fn peak(&self) -> usize {
928 self.peak.load(Ordering::SeqCst)
929 }
930
931 fn total(&self) -> usize {
932 self.total.load(Ordering::SeqCst)
933 }
934 }
935
936 #[derive(Clone)]
937 struct DelayedRefreshState {
938 counting: CountingState,
939 delay: Duration,
940 }
941
942 async fn delayed_refresh_handler(
943 axum::extract::State(state): axum::extract::State<DelayedRefreshState>,
944 ) -> axum::Json<serde_json::Value> {
945 state.counting.enter();
946 tokio::time::sleep(state.delay).await;
947 state.counting.exit();
948 axum::Json(serde_json::json!({
949 "access_token": "refreshed-token",
950 "token_type": "Bearer",
951 "expires_in": 3600,
952 "refresh_token": "new-refresh-token"
953 }))
954 }
955
956 async fn delayed_error_handler(
957 axum::extract::State(state): axum::extract::State<DelayedRefreshState>,
958 ) -> (axum::http::StatusCode, axum::Json<serde_json::Value>) {
959 state.counting.enter();
960 tokio::time::sleep(state.delay).await;
961 state.counting.exit();
962 (
963 axum::http::StatusCode::BAD_REQUEST,
964 axum::Json(serde_json::json!({
965 "error": "invalid_grant",
966 "error_description": "invalid_grant occurred"
967 })),
968 )
969 }
970
971 async fn start_axum_server<H, T>(
972 handler: H,
973 state: DelayedRefreshState,
974 ) -> (url::Url, CountingState)
975 where
976 H: axum::handler::Handler<T, DelayedRefreshState> + Clone + Send + 'static,
977 T: 'static,
978 {
979 let counting = state.counting.clone();
980 let app = axum::Router::new()
981 .route("/oauth/token", axum::routing::post(handler))
982 .with_state(state);
983 let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
984 let addr = listener.local_addr().unwrap();
985 tokio::spawn(async move {
986 axum::serve(listener, app).await.unwrap();
987 });
988 let base_url = url::Url::parse(&format!("http://{addr}")).unwrap();
989 (base_url, counting)
990 }
991
992 fn make_token(access: &str, expires_in: u64, refresh: bool) -> Token {
993 let now = SystemTime::now()
994 .duration_since(UNIX_EPOCH)
995 .unwrap()
996 .as_secs();
997
998 Token {
999 access_token: SecretToken::new(access),
1000 token_type: "Bearer".to_string(),
1001 expires_at: now + expires_in,
1002 refresh_token: if refresh {
1003 Some(SecretToken::new("test-refresh-token"))
1004 } else {
1005 None
1006 },
1007 region: None,
1008 client_id: None,
1009 device_instance_id: None,
1010 }
1011 }
1012
1013 fn auto_refresh_with_token(
1014 dir: &tempfile::TempDir,
1015 base_url: &url::Url,
1016 token: Token,
1017 ) -> AutoRefresh<OAuthRefresher> {
1018 let store = ProfileStore::new(dir.path());
1019 store.init_workspace("ZVATKW3VHMFG27DY").unwrap();
1020 let ws_store = store.current_workspace_store().unwrap();
1021 ws_store.save_profile(&token).unwrap();
1022 let refresher = OAuthRefresher::new(
1023 Some(ws_store),
1024 base_url.clone(),
1025 "cli",
1026 "ap-southeast-2.aws",
1027 None,
1028 );
1029 AutoRefresh::with_token(refresher, token)
1030 }
1031
    /// Number of concurrent `get_token` callers each stress test spawns.
    const CONCURRENCY: usize = 50;
1033
1034 mod given_fresh_token {
1035 use super::*;
1036
1037 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1038 async fn all_callers_return_immediately() {
1039 let counting = CountingState::new();
1040 let state = DelayedRefreshState {
1041 counting: counting.clone(),
1042 delay: Duration::from_millis(500),
1043 };
1044 let (base_url, stats) = start_axum_server(delayed_refresh_handler, state).await;
1045 let dir = tempfile::tempdir().unwrap();
1046 let strategy = Arc::new(auto_refresh_with_token(
1047 &dir,
1048 &base_url,
1049 make_token("fresh-token", 3600, true),
1050 ));
1051
1052 let start = Instant::now();
1053 let mut handles = Vec::with_capacity(CONCURRENCY);
1054 for _ in 0..CONCURRENCY {
1055 let s = Arc::clone(&strategy);
1056 handles.push(tokio::spawn(async move { s.get_token().await.unwrap() }));
1057 }
1058
1059 let results: Vec<_> = {
1060 let mut results = Vec::with_capacity(handles.len());
1061 for handle in handles {
1062 results.push(handle.await.unwrap());
1063 }
1064 results
1065 };
1066 let elapsed = start.elapsed();
1067
1068 for token in &results {
1069 assert_eq!(
1070 token.as_str(),
1071 "fresh-token",
1072 "all callers should receive the fresh token"
1073 );
1074 }
1075
1076 assert!(
1077 elapsed < Duration::from_millis(200),
1078 "expected < 200ms for fresh tokens, got {:?}",
1079 elapsed
1080 );
1081 assert_eq!(stats.total(), 0, "no refresh requests should be made");
1082 }
1083 }
1084
1085 mod given_expiring_but_usable_token {
1086 use super::*;
1087
1088 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1089 async fn non_blocking_reads_during_refresh() {
1090 let counting = CountingState::new();
1091 let state = DelayedRefreshState {
1092 counting: counting.clone(),
1093 delay: Duration::from_millis(500),
1094 };
1095 let (base_url, stats) = start_axum_server(delayed_refresh_handler, state).await;
1096 let dir = tempfile::tempdir().unwrap();
1097 let strategy = Arc::new(auto_refresh_with_token(
1098 &dir,
1099 &base_url,
1100 make_token("still-usable", 30, true),
1101 ));
1102
1103 let start = Instant::now();
1104 let mut handles = Vec::with_capacity(CONCURRENCY);
1105 for _ in 0..CONCURRENCY {
1106 let s = Arc::clone(&strategy);
1107 handles.push(tokio::spawn(async move {
1108 let call_start = Instant::now();
1109 let token = s.get_token().await.unwrap();
1110 (token, call_start.elapsed())
1111 }));
1112 }
1113
1114 let results: Vec<_> = {
1115 let mut results = Vec::with_capacity(handles.len());
1116 for handle in handles {
1117 results.push(handle.await.unwrap());
1118 }
1119 results
1120 };
1121 let elapsed = start.elapsed();
1122
1123 for (token, _) in &results {
1124 assert!(
1125 token.as_str() == "still-usable" || token.as_str() == "refreshed-token",
1126 "unexpected token: {}",
1127 token.as_str()
1128 );
1129 }
1130
1131 let fast_callers = results
1132 .iter()
1133 .filter(|(_, dur)| *dur < Duration::from_millis(100))
1134 .count();
1135 assert!(
1136 fast_callers >= CONCURRENCY - 1,
1137 "expected at least {} fast callers, got {} (total elapsed: {:?})",
1138 CONCURRENCY - 1,
1139 fast_callers,
1140 elapsed
1141 );
1142
1143 assert_eq!(stats.peak(), 1, "peak concurrency to refresh endpoint");
1144 assert_eq!(stats.total(), 1, "total refresh requests");
1145 }
1146
1147 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1152 async fn waiters_receive_token_when_expiry_crosses() {
1153 let refresh_delay = Duration::from_millis(1500);
1158 let counting = CountingState::new();
1159 let state = DelayedRefreshState {
1160 counting: counting.clone(),
1161 delay: refresh_delay,
1162 };
1163 let (base_url, stats) = start_axum_server(delayed_refresh_handler, state).await;
1164 let dir = tempfile::tempdir().unwrap();
1165 let strategy = Arc::new(auto_refresh_with_token(
1166 &dir,
1167 &base_url,
1168 make_token("expiring-soon", 1, true),
1169 ));
1170
1171 let first = strategy.get_token().await.unwrap();
1173 assert_eq!(
1174 first.as_str(),
1175 "expiring-soon",
1176 "first caller should receive the expiring token"
1177 );
1178
1179 tokio::time::sleep(Duration::from_millis(1100)).await;
1181
1182 let mut handles = Vec::with_capacity(CONCURRENCY);
1185 for _ in 0..CONCURRENCY {
1186 let s = Arc::clone(&strategy);
1187 handles.push(tokio::spawn(async move { s.get_token().await }));
1188 }
1189
1190 let results: Vec<_> = {
1191 let mut results = Vec::with_capacity(handles.len());
1192 for handle in handles {
1193 results.push(handle.await.unwrap());
1194 }
1195 results
1196 };
1197
1198 for (i, result) in results.iter().enumerate() {
1200 assert!(
1201 result.is_ok(),
1202 "caller {i} got Err({:?}), expected Ok",
1203 result.as_ref().unwrap_err()
1204 );
1205 assert_eq!(
1206 result.as_ref().unwrap().as_str(),
1207 "refreshed-token",
1208 "caller {i} should receive the refreshed token"
1209 );
1210 }
1211
1212 assert_eq!(stats.total(), 1, "only one refresh request should be made");
1213 }
1214 }
1215
1216 mod given_fully_expired_token {
1217 use super::*;
1218
1219 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1220 async fn all_callers_block_until_refresh() {
1221 let refresh_delay = Duration::from_millis(200);
1222 let counting = CountingState::new();
1223 let state = DelayedRefreshState {
1224 counting: counting.clone(),
1225 delay: refresh_delay,
1226 };
1227 let (base_url, stats) = start_axum_server(delayed_refresh_handler, state).await;
1228 let dir = tempfile::tempdir().unwrap();
1229 let strategy = Arc::new(auto_refresh_with_token(
1230 &dir,
1231 &base_url,
1232 make_token("expired-token", 0, true),
1233 ));
1234
1235 let start = Instant::now();
1236 let mut handles = Vec::with_capacity(CONCURRENCY);
1237 for _ in 0..CONCURRENCY {
1238 let s = Arc::clone(&strategy);
1239 handles.push(tokio::spawn(async move { s.get_token().await.unwrap() }));
1240 }
1241
1242 let results: Vec<_> = {
1243 let mut results = Vec::with_capacity(handles.len());
1244 for handle in handles {
1245 results.push(handle.await.unwrap());
1246 }
1247 results
1248 };
1249 let elapsed = start.elapsed();
1250
1251 for token in &results {
1252 assert_eq!(
1253 token.as_str(),
1254 "refreshed-token",
1255 "all callers should receive refreshed token"
1256 );
1257 }
1258
1259 assert!(
1260 elapsed < refresh_delay + Duration::from_millis(200),
1261 "expected < {:?} for blocked callers, got {:?}",
1262 refresh_delay + Duration::from_millis(200),
1263 elapsed
1264 );
1265
1266 assert_eq!(stats.peak(), 1, "peak concurrency to refresh endpoint");
1267 assert_eq!(stats.total(), 1, "total refresh requests");
1268 }
1269
1270 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1271 async fn all_callers_receive_expired_on_failure() {
1272 let counting = CountingState::new();
1273 let state = DelayedRefreshState {
1274 counting: counting.clone(),
1275 delay: Duration::from_millis(10),
1276 };
1277 let (base_url, stats) = start_axum_server(delayed_error_handler, state).await;
1278 let dir = tempfile::tempdir().unwrap();
1279 let strategy = Arc::new(auto_refresh_with_token(
1280 &dir,
1281 &base_url,
1282 make_token("expired-token", 0, true),
1283 ));
1284
1285 let mut handles = Vec::with_capacity(CONCURRENCY);
1286 for _ in 0..CONCURRENCY {
1287 let s = Arc::clone(&strategy);
1288 handles.push(tokio::spawn(async move { s.get_token().await }));
1289 }
1290
1291 let results: Vec<_> = {
1292 let mut results = Vec::with_capacity(handles.len());
1293 for handle in handles {
1294 results.push(handle.await.unwrap());
1295 }
1296 results
1297 };
1298
1299 for result in &results {
1300 assert!(result.is_err(), "expected Expired error, got Ok");
1301 let err = result.as_ref().unwrap_err();
1302 assert!(
1303 matches!(err, AutoRefreshError::Expired),
1304 "expected Expired, got: {err:?}"
1305 );
1306 }
1307
1308 let state = strategy.state.lock().await;
1309 assert!(
1310 state.token.as_ref().unwrap().refresh_token().is_some(),
1311 "refresh token should be restored after failed refresh"
1312 );
1313 drop(state);
1314
1315 assert_eq!(stats.peak(), 1, "peak concurrency to refresh endpoint");
1316 assert!(
1317 stats.total() >= 1,
1318 "at least one refresh attempt should be made"
1319 );
1320 }
1321
1322 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1323 async fn retry_succeeds_after_failure() {
1324 let counting1 = CountingState::new();
1326 let state1 = DelayedRefreshState {
1327 counting: counting1.clone(),
1328 delay: Duration::from_millis(50),
1329 };
1330 let (base_url, _) = start_axum_server(delayed_error_handler, state1).await;
1331 let dir = tempfile::tempdir().unwrap();
1332 let strategy = Arc::new(auto_refresh_with_token(
1333 &dir,
1334 &base_url,
1335 make_token("expired-token", 0, true),
1336 ));
1337
1338 let mut handles = Vec::with_capacity(CONCURRENCY);
1339 for _ in 0..CONCURRENCY {
1340 let s = Arc::clone(&strategy);
1341 handles.push(tokio::spawn(async move { s.get_token().await }));
1342 }
1343
1344 let results: Vec<_> = {
1345 let mut results = Vec::with_capacity(handles.len());
1346 for handle in handles {
1347 results.push(handle.await.unwrap());
1348 }
1349 results
1350 };
1351
1352 for result in &results {
1353 assert!(
1354 result.is_err(),
1355 "first wave: expected Expired, got Ok({})",
1356 result.as_ref().unwrap().as_str()
1357 );
1358 }
1359
1360 let counting2 = CountingState::new();
1362 let state2 = DelayedRefreshState {
1363 counting: counting2.clone(),
1364 delay: Duration::from_millis(50),
1365 };
1366 let (base_url2, stats2) = start_axum_server(delayed_refresh_handler, state2).await;
1367
1368 let strategy2 = Arc::new(auto_refresh_with_token(
1369 &dir,
1370 &base_url2,
1371 make_token("expired-token", 0, true),
1372 ));
1373
1374 let mut handles = Vec::with_capacity(CONCURRENCY);
1375 for _ in 0..CONCURRENCY {
1376 let s = Arc::clone(&strategy2);
1377 handles.push(tokio::spawn(async move { s.get_token().await.unwrap() }));
1378 }
1379
1380 let results: Vec<_> = {
1381 let mut results = Vec::with_capacity(handles.len());
1382 for handle in handles {
1383 results.push(handle.await.unwrap());
1384 }
1385 results
1386 };
1387
1388 for token in &results {
1389 assert_eq!(
1390 token.as_str(),
1391 "refreshed-token",
1392 "retry callers should receive refreshed token"
1393 );
1394 }
1395
1396 assert_eq!(stats2.total(), 1, "only one retry refresh should be made");
1397 }
1398 }
1399
1400 mod given_cancelled_refresh {
1401 use super::*;
1402
1403 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1407 async fn blocked_callers_recover_after_cancellation() {
1408 let counting = CountingState::new();
1409 let state = DelayedRefreshState {
1410 counting: counting.clone(),
1411 delay: Duration::from_secs(10), };
1413 let (base_url, _) = start_axum_server(delayed_refresh_handler, state).await;
1414 let dir = tempfile::tempdir().unwrap();
1415 let strategy = Arc::new(auto_refresh_with_token(
1416 &dir,
1417 &base_url,
1418 make_token("expired-token", 0, true),
1419 ));
1420
1421 let s = Arc::clone(&strategy);
1423 let handle = tokio::spawn(async move { s.get_token().await });
1424 tokio::time::sleep(Duration::from_millis(100)).await;
1425
1426 handle.abort();
1428 let _ = handle.await;
1429
1430 let s = Arc::clone(&strategy);
1434 let result = tokio::time::timeout(Duration::from_secs(2), s.get_token()).await;
1435
1436 assert!(
1437 result.is_ok(),
1438 "get_token() should not hang after cancelled blocking refresh"
1439 );
1440 }
1441
1442 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1446 async fn non_blocking_callers_recover_after_cancellation() {
1447 let counting = CountingState::new();
1448 let state = DelayedRefreshState {
1449 counting: counting.clone(),
1450 delay: Duration::from_secs(10), };
1452 let (base_url, _) = start_axum_server(delayed_refresh_handler, state).await;
1453 let dir = tempfile::tempdir().unwrap();
1454 let strategy = Arc::new(auto_refresh_with_token(
1456 &dir,
1457 &base_url,
1458 make_token("still-usable", 30, true),
1459 ));
1460
1461 let s = Arc::clone(&strategy);
1464 let handle = tokio::spawn(async move { s.get_token().await });
1465 tokio::time::sleep(Duration::from_millis(100)).await;
1466
1467 handle.abort();
1469 let _ = handle.await;
1470
1471 let s = Arc::clone(&strategy);
1474 let result = tokio::time::timeout(Duration::from_secs(2), s.get_token()).await;
1475
1476 assert!(
1477 result.is_ok(),
1478 "get_token() should not hang after cancelled non-blocking refresh"
1479 );
1480 let result = result.unwrap();
1481 assert!(
1482 result.is_ok(),
1483 "expected Ok with still-usable token, got: {:?}",
1484 result.unwrap_err()
1485 );
1486 }
1487 }
1488}