1use std::sync::atomic::{AtomicBool, Ordering};
2
3use tokio::sync::{Mutex, MutexGuard, Notify};
4
5use crate::refresher::Refresher;
6use crate::token_store::{NoStore, TokenStore};
7use crate::{ServiceToken, Token};
8
/// Errors returned by [`AutoRefresh`] when a token cannot be produced.
#[derive(Debug, thiserror::Error)]
pub(crate) enum AutoRefreshError {
    /// No token is cached and no credential was available to obtain one.
    #[error("No token found")]
    NotFound,
    /// A token is cached but is not usable, and it could not be replaced.
    #[error("Token has expired")]
    Expired,
    /// An underlying [`crate::AuthError`], carried through unchanged.
    #[error("Auth error: {0}")]
    Auth(#[from] crate::AuthError),
}
25
26impl From<AutoRefreshError> for crate::AuthError {
27 fn from(err: AutoRefreshError) -> Self {
28 match err {
29 AutoRefreshError::NotFound => crate::AuthError::NotAuthenticated,
30 AutoRefreshError::Expired => crate::AuthError::TokenExpired,
31 AutoRefreshError::Auth(e) => e,
32 }
33 }
34}
35
/// Caches a [`Token`] and refreshes it on demand through a refresher `R`,
/// optionally loading from / saving to a store `S` (defaults to [`NoStore`]).
pub(crate) struct AutoRefresh<R, S = NoStore> {
    /// Supplies refresh credentials and performs the refresh itself.
    refresher: R,
    /// The cached token, guarded by an async mutex.
    state: Mutex<State>,
    /// Consulted when no token is cached; written after each successful refresh.
    store: S,
    /// Set while one caller is running a refresh.
    refresh_in_progress: AtomicBool,
    /// Used to wake callers that are waiting on an in-flight refresh.
    refresh_notify: Notify,
}
54
/// Mutable state protected by the `AutoRefresh` mutex.
struct State {
    /// The currently cached token, or `None` when nothing has been loaded yet.
    token: Option<Token>,
}
58
/// Drop guard that resets the refresh-in-progress flag and wakes waiters if
/// it is dropped without having been defused.
struct CancelGuard<'a> {
    /// Flag to clear on drop.
    in_progress: &'a AtomicBool,
    /// Notifier whose waiters are woken on drop.
    notify: &'a Notify,
    /// When `true`, dropping the guard does nothing.
    defused: bool,
}
69
70impl Drop for CancelGuard<'_> {
71 fn drop(&mut self) {
72 if !self.defused {
73 self.in_progress.store(false, Ordering::Release);
74 self.notify.notify_waiters();
75 }
76 }
77}
78
impl CancelGuard<'_> {
    /// Disarms the guard so that dropping it leaves the flag and waiters alone.
    fn defuse(&mut self) {
        self.defused = true;
    }
}
84
85impl State {
86 fn service_token(&self) -> Result<ServiceToken, AutoRefreshError> {
87 let token = self.token.as_ref().ok_or(AutoRefreshError::NotFound)?;
88 Ok(ServiceToken::new(token.access_token().clone()))
89 }
90
91 fn require_usable_token(&self) -> Result<ServiceToken, AutoRefreshError> {
92 let token = self.token.as_ref().ok_or(AutoRefreshError::NotFound)?;
93 if token.is_usable() {
94 Ok(ServiceToken::new(token.access_token().clone()))
95 } else {
96 Err(AutoRefreshError::Expired)
97 }
98 }
99}
100
101impl<R> AutoRefresh<R, NoStore> {
102 pub(crate) fn with_token(refresher: R, token: Token) -> Self {
107 Self {
108 refresher,
109 state: Mutex::new(State { token: Some(token) }),
110 store: NoStore,
111 refresh_in_progress: AtomicBool::new(false),
112 refresh_notify: Notify::new(),
113 }
114 }
115}
116
117impl<R, S: TokenStore> AutoRefresh<R, S> {
118 pub(crate) fn with_store(refresher: R, store: S) -> Self {
126 Self {
127 refresher,
128 state: Mutex::new(State { token: None }),
129 store,
130 refresh_in_progress: AtomicBool::new(false),
131 refresh_notify: Notify::new(),
132 }
133 }
134}
135
136impl<R: Refresher, S: TokenStore> AutoRefresh<R, S> {
/// Returns a service token, refreshing the cached token first when it has
/// expired.
///
/// # Errors
/// - `NotFound` when nothing is cached and no credential is available.
/// - `Expired` when the cached token is unusable and could not be refreshed.
/// - `Auth` when the initial authentication request fails.
pub(crate) async fn get_token(&self) -> Result<ServiceToken, AutoRefreshError> {
    let mut state = self.state.lock().await;

    // Nothing cached: load from the store with the lock released, then
    // re-check under the lock so a token installed in the meantime wins.
    if state.token.is_none() {
        drop(state);
        let loaded = self.store.load().await;
        state = self.state.lock().await;
        if state.token.is_none() {
            state.token = loaded;
        }
    }

    // Still nothing: authenticate from scratch while holding the lock.
    if state.token.is_none() {
        return self.initial_auth(&mut state).await;
    }

    // Fast path: the cached token does not report itself as expired.
    if !state.token.as_ref().is_some_and(|t| t.is_expired()) {
        return state.service_token();
    }

    // Another caller already owns the refresh; piggy-back on it.
    if self.refresh_in_progress.load(Ordering::Acquire) {
        return self.wait_for_in_flight_refresh(state).await;
    }

    // Without a credential no refresh is possible; fall back to the cached
    // token if it is still usable.
    let Some(credential) = self.refresher.try_credential(state.token.as_mut()) else {
        return state.require_usable_token();
    };

    // Claim the refresh. This happens while the state lock is held.
    self.refresh_in_progress.store(true, Ordering::Release);

    // A usable token lets the lock be released during the refresh; an
    // unusable one keeps the lock held so other callers queue behind it.
    if state.token.as_ref().is_some_and(|t| t.is_usable()) {
        self.refresh_non_blocking(state, credential).await
    } else {
        self.refresh_blocking(&mut state, credential).await
    }
}
178
/// Obtains a first token when nothing is cached, storing it in `state`.
///
/// # Errors
/// Returns `NotFound` when the refresher has no credential to offer and
/// `Auth` when the refresh request fails.
async fn initial_auth(&self, state: &mut State) -> Result<ServiceToken, AutoRefreshError> {
    let Some(credential) = self.refresher.try_credential(None) else {
        return Err(AutoRefreshError::NotFound);
    };
    self.refresh_in_progress.store(true, Ordering::Release);
    // If this future is dropped mid-refresh the guard clears the flag and
    // wakes waiters.
    let mut guard = CancelGuard {
        in_progress: &self.refresh_in_progress,
        notify: &self.refresh_notify,
        defused: false,
    };
    match self.refresher.refresh(&credential).await {
        Ok(new_token) => {
            self.save_refreshed_token(&new_token).await;
            // Installing the token also clears the in-progress flag.
            let token = self.install_refreshed_token(state, new_token);
            guard.defuse();
            Ok(token)
        }
        Err(err) => {
            // The flag is cleared by hand here, so the guard is disarmed
            // (which also means no waiters are notified on this path).
            guard.defuse();
            self.refresh_in_progress.store(false, Ordering::Release);
            Err(AutoRefreshError::Auth(err))
        }
    }
}
206
/// Hands a freshly obtained token to the refresher and then to the store.
async fn save_refreshed_token(&self, new_token: &Token) {
    self.refresher.save(new_token);
    self.store.save(new_token).await;
}
216
/// Replaces the cached token with `new_token`, clears the in-progress flag,
/// and returns a [`ServiceToken`] built from the new access token.
fn install_refreshed_token(&self, state: &mut State, new_token: Token) -> ServiceToken {
    let service_token = ServiceToken::new(new_token.access_token().clone());
    state.token = Some(new_token);
    self.refresh_in_progress.store(false, Ordering::Release);
    service_token
}
227
228 async fn wait_for_in_flight_refresh(
234 &self,
235 state: MutexGuard<'_, State>,
236 ) -> Result<ServiceToken, AutoRefreshError> {
237 if let Ok(token) = state.service_token() {
238 if state.token.as_ref().is_some_and(|t| t.is_usable()) {
239 return Ok(token);
240 }
241 }
242 let notified = self.refresh_notify.notified();
245 drop(state);
246 notified.await;
247 let state = self.state.lock().await;
249 state.require_usable_token()
250 }
251
/// Refreshes the token with the state lock released, so that other callers
/// can keep reading the cached token in the meantime.
///
/// Always returns the token that was cached when the refresh started; a
/// refresh failure is logged rather than returned.
///
/// # Errors
/// Returns `NotFound` only if no token is cached on entry.
async fn refresh_non_blocking(
    &self,
    state: MutexGuard<'_, State>,
    credential: R::Credential,
) -> Result<ServiceToken, AutoRefreshError> {
    // Capture the current token, then release the lock for the refresh.
    let current_service_token = state.service_token()?;
    drop(state);

    // Clears the flag and wakes waiters if this future is dropped early.
    let mut guard = CancelGuard {
        in_progress: &self.refresh_in_progress,
        notify: &self.refresh_notify,
        defused: false,
    };

    match self.refresher.refresh(&credential).await {
        Ok(new_token) => {
            self.save_refreshed_token(&new_token).await;
            let mut state = self.state.lock().await;
            // Installing the token clears the in-progress flag under the lock.
            let _ = self.install_refreshed_token(&mut state, new_token);
            guard.defuse();
        }
        Err(err) => {
            tracing::warn!(%err, "token refresh failed (token still usable)");
            let mut state = self.state.lock().await;
            // Give the credential back to the cached token.
            if let Some(token) = state.token.as_mut() {
                self.refresher.restore(token, credential);
            }
            self.refresh_in_progress.store(false, Ordering::Release);
            guard.defuse();
        }
    }

    // Wake callers parked in `wait_for_in_flight_refresh`; the state lock
    // taken above has been released by now.
    self.refresh_notify.notify_waiters();
    Ok(current_service_token)
}
307
/// Refreshes the token through the caller's `&mut State` and returns the
/// new token. `get_token` passes its held lock guard here, so the state lock
/// stays held for the duration of the refresh.
///
/// # Errors
/// Returns `Expired` when the refresh request fails; the underlying error
/// is logged, not returned.
async fn refresh_blocking(
    &self,
    state: &mut State,
    credential: R::Credential,
) -> Result<ServiceToken, AutoRefreshError> {
    // Clears the flag and wakes waiters if this future is dropped early.
    let mut guard = CancelGuard {
        in_progress: &self.refresh_in_progress,
        notify: &self.refresh_notify,
        defused: false,
    };
    match self.refresher.refresh(&credential).await {
        Ok(new_token) => {
            self.save_refreshed_token(&new_token).await;
            // Installing the token also clears the in-progress flag.
            let token = self.install_refreshed_token(state, new_token);
            guard.defuse();
            Ok(token)
        }
        Err(err) => {
            // Cleanup is done by hand below, so the guard is disarmed first.
            guard.defuse();
            tracing::warn!(%err, "token refresh failed");
            // Give the credential back to the cached token.
            if let Some(token) = state.token.as_mut() {
                self.refresher.restore(token, credential);
            }
            self.refresh_in_progress.store(false, Ordering::Release);
            Err(AutoRefreshError::Expired)
        }
    }
}
345}
346
347#[cfg(test)]
348#[allow(clippy::unwrap_used)]
349mod tests {
350 use super::*;
351 use crate::device_session_refresher::DeviceSessionRefresher;
352 use crate::SecretToken;
353 use mocktail::prelude::*;
354 use stack_profile::ProfileStore;
355 use std::sync::Arc;
356 use std::time::{SystemTime, UNIX_EPOCH};
357
358 fn make_token(access: &str, expires_in: u64, refresh: bool) -> Token {
359 let now = SystemTime::now()
360 .duration_since(UNIX_EPOCH)
361 .unwrap()
362 .as_secs();
363
364 Token {
365 access_token: SecretToken::new(access),
366 token_type: "Bearer".to_string(),
367 expires_at: now + expires_in,
368 refresh_token: if refresh {
369 Some(SecretToken::new("test-refresh-token"))
370 } else {
371 None
372 },
373 region: None,
374 client_id: None,
375 device_instance_id: None,
376 }
377 }
378
/// JSON body for a successful token response carrying `access` as the
/// access token, a one-hour lifetime, and a new refresh token.
fn refresh_response_json(access: &str) -> serde_json::Value {
    serde_json::json!({
        "access_token": access,
        "token_type": "Bearer",
        "expires_in": 3600,
        "refresh_token": "new-refresh-token"
    })
}
387
388 fn error_json(error: &str) -> serde_json::Value {
389 serde_json::json!({
390 "error": error,
391 "error_description": format!("{error} occurred")
392 })
393 }
394
/// Starts an HTTP mock server preloaded with `mocks` and returns it.
/// Panics if the server fails to start.
async fn start_server(mocks: MockSet) -> MockServer {
    let server = MockServer::new_http("auto-refresh-test").with_mocks(mocks);
    server.start().await.unwrap();
    server
}
400
401 fn auto_refresh_with_token(
402 dir: &tempfile::TempDir,
403 server: &MockServer,
404 token: Token,
405 ) -> AutoRefresh<DeviceSessionRefresher> {
406 let store = ProfileStore::new(dir.path());
407 store.init_workspace("ZVATKW3VHMFG27DY").unwrap();
408 let ws_store = store.current_workspace_store().unwrap();
409 ws_store.save_profile(&token).unwrap();
410 let refresher = DeviceSessionRefresher::new(
411 Some(ws_store),
412 server.url(""),
413 "cli",
414 "ap-southeast-2.aws",
415 None,
416 );
417 AutoRefresh::with_token(refresher, token)
418 }
419
420 mod given_no_cached_token {
421 use super::*;
422
423 #[tokio::test]
424 async fn returns_not_found_for_oauth() {
425 let server = start_server(MockSet::new()).await;
426 let store = ProfileStore::new("/tmp/nonexistent");
427 let refresher = DeviceSessionRefresher::new(
428 Some(store),
429 server.url(""),
430 "cli",
431 "ap-southeast-2.aws",
432 None,
433 );
434 let strategy = AutoRefresh::with_store(refresher, NoStore);
435
436 let err = strategy.get_token().await.unwrap_err();
437
438 assert!(
439 matches!(err, AutoRefreshError::NotFound),
440 "expected NotFound, got: {err:?}"
441 );
442 }
443 }
444
445 mod given_fresh_token {
446 use super::*;
447
448 #[tokio::test]
449 async fn returns_cached_token() {
450 let dir = tempfile::tempdir().unwrap();
451 let server = start_server(MockSet::new()).await;
452 let strategy =
453 auto_refresh_with_token(&dir, &server, make_token("my-access-token", 3600, false));
454
455 let token = strategy.get_token().await.unwrap();
456
457 assert_eq!(
458 token.as_str(),
459 "my-access-token",
460 "should return the cached access token"
461 );
462 }
463
464 #[tokio::test]
465 async fn caches_across_calls() {
466 let dir = tempfile::tempdir().unwrap();
467 let server = start_server(MockSet::new()).await;
468 let strategy =
469 auto_refresh_with_token(&dir, &server, make_token("my-access-token", 3600, false));
470
471 let token1 = strategy.get_token().await.unwrap();
472 assert_eq!(
473 token1.as_str(),
474 "my-access-token",
475 "first call should return the cached token"
476 );
477
478 std::fs::remove_file(
480 dir.path()
481 .join("workspaces")
482 .join("ZVATKW3VHMFG27DY")
483 .join("auth.json"),
484 )
485 .unwrap();
486
487 let token2 = strategy.get_token().await.unwrap();
488 assert_eq!(
489 token2.as_str(),
490 "my-access-token",
491 "second call should return the cached token even after file deletion"
492 );
493 }
494
495 #[tokio::test]
496 async fn does_not_trigger_refresh() {
497 let mut mocks = MockSet::new();
499 mocks.mock(|when, then| {
500 when.post().path("/oauth/token");
501 then.internal_server_error()
502 .json(error_json("should_not_be_called"));
503 });
504 let server = start_server(mocks).await;
505 let dir = tempfile::tempdir().unwrap();
506 let strategy =
507 auto_refresh_with_token(&dir, &server, make_token("fresh-token", 3600, true));
508
509 let token = strategy.get_token().await.unwrap();
510
511 assert_eq!(
512 token.as_str(),
513 "fresh-token",
514 "should return fresh token without triggering refresh"
515 );
516 }
517 }
518
/// Tests for a cached token created with `expires_in = 0`.
mod given_fully_expired_token {
    use super::*;

    mod without_refresh_token {
        use super::*;

        /// With no refresh token there is nothing to refresh with, so the
        /// call fails with `Expired`.
        #[tokio::test]
        async fn returns_expired() {
            let dir = tempfile::tempdir().unwrap();
            let server = start_server(MockSet::new()).await;
            let strategy =
                auto_refresh_with_token(&dir, &server, make_token("old-token", 0, false));

            let err = strategy.get_token().await.unwrap_err();

            assert!(
                matches!(err, AutoRefreshError::Expired),
                "expected Expired, got: {err:?}"
            );
        }
    }

    mod with_refresh_token {
        use super::*;

        /// A successful refresh yields the new access token to the caller.
        #[tokio::test]
        async fn refreshes_and_returns_new_token() {
            let mut mocks = MockSet::new();
            mocks.mock(|when, then| {
                when.post().path("/oauth/token");
                then.json(refresh_response_json("refreshed-token"));
            });
            let server = start_server(mocks).await;
            let dir = tempfile::tempdir().unwrap();
            let strategy =
                auto_refresh_with_token(&dir, &server, make_token("old-token", 0, true));

            let token = strategy.get_token().await.unwrap();

            assert_eq!(
                token.as_str(),
                "refreshed-token",
                "should return the refreshed token"
            );
        }

        /// After a successful refresh the profile on disk holds the new token.
        #[tokio::test]
        async fn persists_refreshed_token_to_disk() {
            let mut mocks = MockSet::new();
            mocks.mock(|when, then| {
                when.post().path("/oauth/token");
                then.json(refresh_response_json("refreshed-token"));
            });
            let server = start_server(mocks).await;
            let dir = tempfile::tempdir().unwrap();
            let strategy =
                auto_refresh_with_token(&dir, &server, make_token("old-token", 0, true));

            let _ = strategy.get_token().await.unwrap();

            // Re-open the store from disk to check what was persisted.
            let store = ProfileStore::new(dir.path());
            let ws_store = store.current_workspace_store().unwrap();
            let on_disk: Token = ws_store.load_profile().unwrap();
            assert_eq!(
                on_disk.access_token().as_str(),
                "refreshed-token",
                "refreshed token should be persisted to disk"
            );
        }

        /// A rejected refresh surfaces as `Expired`.
        #[tokio::test]
        async fn returns_expired_on_refresh_failure() {
            let mut mocks = MockSet::new();
            mocks.mock(|when, then| {
                when.post().path("/oauth/token");
                then.bad_request().json(error_json("invalid_grant"));
            });
            let server = start_server(mocks).await;
            let dir = tempfile::tempdir().unwrap();
            let strategy =
                auto_refresh_with_token(&dir, &server, make_token("old-token", 0, true));

            let err = strategy.get_token().await.unwrap_err();

            assert!(
                matches!(err, AutoRefreshError::Expired),
                "expected Expired after failed refresh, got: {err:?}"
            );
        }

        /// A failed refresh leaves the refresh token in the cache, so a
        /// later attempt against a healthy endpoint succeeds.
        #[tokio::test]
        async fn restores_refresh_token_after_failure() {
            let mut mocks = MockSet::new();
            mocks.mock(|when, then| {
                when.post().path("/oauth/token");
                then.bad_request().json(error_json("invalid_grant"));
            });
            let server = start_server(mocks).await;
            let dir = tempfile::tempdir().unwrap();
            let strategy =
                auto_refresh_with_token(&dir, &server, make_token("old-token", 0, true));

            // First attempt: the endpoint rejects the refresh.
            let err = strategy.get_token().await.unwrap_err();
            assert!(
                matches!(err, AutoRefreshError::Expired),
                "expected Expired on first attempt, got: {err:?}"
            );

            // Inspect the cache directly; release the lock before calling
            // `get_token` again.
            let state = strategy.state.lock().await;
            assert!(
                state.token.is_some(),
                "token should still be cached after failed refresh"
            );
            assert!(
                state.token.as_ref().unwrap().refresh_token().is_some(),
                "refresh token should be restored for retry"
            );
            drop(state);

            // Swap the failing mock for a succeeding one.
            server.mocks().clear();
            server.mocks().mock(|when, then| {
                when.post().path("/oauth/token");
                then.json(refresh_response_json("refreshed-token"));
            });

            let token = strategy.get_token().await.unwrap();
            assert_eq!(
                token.as_str(),
                "refreshed-token",
                "retry should succeed with restored refresh token"
            );
        }

        /// Once refreshed, later sequential calls are served from the cache
        /// even though the endpoint would now hand out a different token.
        #[tokio::test]
        async fn sequential_calls_only_refresh_once() {
            let mut mocks = MockSet::new();
            mocks.mock(|when, then| {
                when.post().path("/oauth/token");
                then.json(refresh_response_json("refreshed-once"));
            });
            let server = start_server(mocks).await;
            let dir = tempfile::tempdir().unwrap();
            let strategy =
                auto_refresh_with_token(&dir, &server, make_token("old-token", 0, true));

            let token = strategy.get_token().await.unwrap();
            assert_eq!(
                token.as_str(),
                "refreshed-once",
                "first call should trigger refresh"
            );

            // A second refresh would be observable as "refreshed-twice".
            server.mocks().clear();
            server.mocks().mock(|when, then| {
                when.post().path("/oauth/token");
                then.json(refresh_response_json("refreshed-twice"));
            });

            for _ in 0..4 {
                let token = strategy.get_token().await.unwrap();
                assert_eq!(
                    token.as_str(),
                    "refreshed-once",
                    "should return cached refreshed token, not trigger another refresh"
                );
            }
        }

        /// After a successful refresh, a second call returns the cached
        /// token even though the endpoint is now mocked to fail.
        #[tokio::test]
        async fn prevents_second_refresh_after_success() {
            let mut mocks = MockSet::new();
            mocks.mock(|when, then| {
                when.post().path("/oauth/token");
                then.json(refresh_response_json("refreshed-token"));
            });
            let server = start_server(mocks).await;
            let dir = tempfile::tempdir().unwrap();
            let strategy =
                auto_refresh_with_token(&dir, &server, make_token("old-token", 0, true));

            let token = strategy.get_token().await.unwrap();
            assert_eq!(
                token.as_str(),
                "refreshed-token",
                "first call should refresh the token"
            );

            // Any further refresh attempt would now be rejected.
            server.mocks().clear();
            server.mocks().mock(|when, then| {
                when.post().path("/oauth/token");
                then.bad_request().json(error_json("should_not_be_called"));
            });

            let token = strategy.get_token().await.unwrap();
            assert_eq!(
                token.as_str(),
                "refreshed-token",
                "second call should return cached refreshed token"
            );
        }
    }
}
733
/// Tests for a token created with 30 seconds left, which these tests expect
/// to trigger a refresh while still being handed out.
mod given_expiring_but_usable_token {
    use super::*;

    mod when_refresh_fails {
        use super::*;

        /// A failed refresh is not surfaced: the caller gets the current
        /// token and the cache is left intact, refresh token included.
        #[tokio::test]
        async fn returns_current_token() {
            let mut mocks = MockSet::new();
            mocks.mock(|when, then| {
                when.post().path("/oauth/token");
                then.bad_request().json(error_json("server_error"));
            });
            let server = start_server(mocks).await;
            let dir = tempfile::tempdir().unwrap();
            let strategy =
                auto_refresh_with_token(&dir, &server, make_token("still-usable", 30, true));

            let token = strategy.get_token().await.unwrap();
            assert_eq!(
                token.as_str(),
                "still-usable",
                "should return still-usable token despite failed refresh"
            );

            // Inspect the cache directly.
            let state = strategy.state.lock().await;
            assert!(state.token.is_some(), "token should still be cached");
            assert_eq!(
                state.token.as_ref().unwrap().access_token().as_str(),
                "still-usable",
                "access token should be unchanged after failed refresh"
            );
            assert!(
                state.token.as_ref().unwrap().refresh_token().is_some(),
                "refresh token should be restored after failed refresh"
            );
        }

        /// After a failed refresh, a later call can refresh successfully and
        /// the cache ends up holding the new token.
        #[tokio::test]
        async fn restores_refresh_token_for_retry() {
            let mut mocks = MockSet::new();
            mocks.mock(|when, then| {
                when.post().path("/oauth/token");
                then.bad_request().json(error_json("server_error"));
            });
            let server = start_server(mocks).await;
            let dir = tempfile::tempdir().unwrap();
            let strategy =
                auto_refresh_with_token(&dir, &server, make_token("still-usable", 30, true));

            let token = strategy.get_token().await.unwrap();
            assert_eq!(
                token.as_str(),
                "still-usable",
                "first call should return still-usable token"
            );

            // Swap the failing mock for a succeeding one.
            server.mocks().clear();
            server.mocks().mock(|when, then| {
                when.post().path("/oauth/token");
                then.json(refresh_response_json("refreshed-token"));
            });

            // The retry may hand back either the old or the new token.
            let token = strategy.get_token().await.unwrap();
            assert!(
                token.as_str() == "still-usable" || token.as_str() == "refreshed-token",
                "expected old or refreshed token, got: {}",
                token.as_str()
            );

            // Either way, the cache must now contain the refreshed token.
            let state = strategy.state.lock().await;
            assert_eq!(
                state.token.as_ref().unwrap().access_token().as_str(),
                "refreshed-token",
                "cache should hold the refreshed token after retry"
            );
        }
    }
}
823
824 mod given_concurrent_callers {
825 use super::*;
826
827 #[tokio::test]
828 async fn returns_usable_token_while_refreshing() {
829 let mut mocks = MockSet::new();
830 mocks.mock(|when, then| {
831 when.post().path("/oauth/token");
832 then.json(refresh_response_json("refreshed-token"));
833 });
834 let server = start_server(mocks).await;
835 let dir = tempfile::tempdir().unwrap();
836 let strategy = Arc::new(auto_refresh_with_token(
837 &dir,
838 &server,
839 make_token("still-usable", 30, true),
840 ));
841
842 let s1 = Arc::clone(&strategy);
843 let handle_a = tokio::spawn(async move { s1.get_token().await.unwrap() });
844
845 let s2 = Arc::clone(&strategy);
846 let handle_b = tokio::spawn(async move { s2.get_token().await.unwrap() });
847
848 let (result_a, result_b) = tokio::join!(handle_a, handle_b);
849 let token_a = result_a.unwrap();
850 let token_b = result_b.unwrap();
851
852 assert!(
853 token_a.as_str() == "still-usable" || token_a.as_str() == "refreshed-token",
854 "unexpected token_a: {}",
855 token_a.as_str()
856 );
857 assert!(
858 token_b.as_str() == "still-usable" || token_b.as_str() == "refreshed-token",
859 "unexpected token_b: {}",
860 token_b.as_str()
861 );
862 }
863
864 #[tokio::test]
865 async fn blocks_until_refresh_completes() {
866 let mut mocks = MockSet::new();
867 mocks.mock(|when, then| {
868 when.post().path("/oauth/token");
869 then.json(refresh_response_json("refreshed-token"));
870 });
871 let server = start_server(mocks).await;
872 let dir = tempfile::tempdir().unwrap();
873 let strategy = Arc::new(auto_refresh_with_token(
874 &dir,
875 &server,
876 make_token("expired-token", 0, true),
877 ));
878
879 let s1 = Arc::clone(&strategy);
880 let handle_a = tokio::spawn(async move { s1.get_token().await.unwrap() });
881
882 let s2 = Arc::clone(&strategy);
883 let handle_b = tokio::spawn(async move { s2.get_token().await.unwrap() });
884
885 let (result_a, result_b) = tokio::join!(handle_a, handle_b);
886 let token_a = result_a.unwrap();
887 let token_b = result_b.unwrap();
888
889 assert_eq!(
890 token_a.as_str(),
891 "refreshed-token",
892 "caller a should receive refreshed token"
893 );
894 assert_eq!(
895 token_b.as_str(),
896 "refreshed-token",
897 "caller b should receive refreshed token"
898 );
899 }
900 }
901}
902
903#[cfg(test)]
904#[allow(clippy::unwrap_used)]
905mod stress_tests {
906 use super::*;
907 use crate::device_session_refresher::DeviceSessionRefresher;
908 use crate::SecretToken;
909 use stack_profile::ProfileStore;
910 use std::sync::atomic::{AtomicUsize, Ordering};
911 use std::sync::Arc;
912 use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
913
/// Shared request counters; clones share the same underlying atomics.
#[derive(Clone)]
struct CountingState {
    /// Number of `enter` calls so far.
    total: Arc<AtomicUsize>,
    /// Number of `enter` calls not yet matched by `exit`.
    current: Arc<AtomicUsize>,
    /// Highest value `current` has reached.
    peak: Arc<AtomicUsize>,
}

impl CountingState {
    /// Creates a set of counters, all starting at zero.
    fn new() -> Self {
        let zeroed = || Arc::new(AtomicUsize::new(0));
        Self {
            total: zeroed(),
            current: zeroed(),
            peak: zeroed(),
        }
    }

    /// Records the start of a request and updates the peak if needed.
    fn enter(&self) {
        self.total.fetch_add(1, Ordering::SeqCst);
        let in_flight = self.current.fetch_add(1, Ordering::SeqCst) + 1;
        self.peak.fetch_max(in_flight, Ordering::SeqCst);
    }

    /// Records the end of a request.
    fn exit(&self) {
        self.current.fetch_sub(1, Ordering::SeqCst);
    }

    /// Highest number of simultaneously active requests seen so far.
    fn peak(&self) -> usize {
        self.peak.load(Ordering::SeqCst)
    }

    /// Total number of requests seen so far.
    fn total(&self) -> usize {
        self.total.load(Ordering::SeqCst)
    }
}
949
/// State handed to the axum handlers of the stress-test token endpoint.
#[derive(Clone)]
struct DelayedRefreshState {
    /// Counters updated on every request.
    counting: CountingState,
    /// How long each handler sleeps before responding.
    delay: Duration,
}
955
956 async fn delayed_refresh_handler(
957 axum::extract::State(state): axum::extract::State<DelayedRefreshState>,
958 ) -> axum::Json<serde_json::Value> {
959 state.counting.enter();
960 tokio::time::sleep(state.delay).await;
961 state.counting.exit();
962 axum::Json(serde_json::json!({
963 "access_token": "refreshed-token",
964 "token_type": "Bearer",
965 "expires_in": 3600,
966 "refresh_token": "new-refresh-token"
967 }))
968 }
969
970 async fn delayed_error_handler(
971 axum::extract::State(state): axum::extract::State<DelayedRefreshState>,
972 ) -> (axum::http::StatusCode, axum::Json<serde_json::Value>) {
973 state.counting.enter();
974 tokio::time::sleep(state.delay).await;
975 state.counting.exit();
976 (
977 axum::http::StatusCode::BAD_REQUEST,
978 axum::Json(serde_json::json!({
979 "error": "invalid_grant",
980 "error_description": "invalid_grant occurred"
981 })),
982 )
983 }
984
985 async fn start_axum_server<H, T>(
986 handler: H,
987 state: DelayedRefreshState,
988 ) -> (url::Url, CountingState)
989 where
990 H: axum::handler::Handler<T, DelayedRefreshState> + Clone + Send + 'static,
991 T: 'static,
992 {
993 let counting = state.counting.clone();
994 let app = axum::Router::new()
995 .route("/oauth/token", axum::routing::post(handler))
996 .with_state(state);
997 let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
998 let addr = listener.local_addr().unwrap();
999 tokio::spawn(async move {
1000 axum::serve(listener, app).await.unwrap();
1001 });
1002 let base_url = url::Url::parse(&format!("http://{addr}")).unwrap();
1003 (base_url, counting)
1004 }
1005
1006 fn make_token(access: &str, expires_in: u64, refresh: bool) -> Token {
1007 let now = SystemTime::now()
1008 .duration_since(UNIX_EPOCH)
1009 .unwrap()
1010 .as_secs();
1011
1012 Token {
1013 access_token: SecretToken::new(access),
1014 token_type: "Bearer".to_string(),
1015 expires_at: now + expires_in,
1016 refresh_token: if refresh {
1017 Some(SecretToken::new("test-refresh-token"))
1018 } else {
1019 None
1020 },
1021 region: None,
1022 client_id: None,
1023 device_instance_id: None,
1024 }
1025 }
1026
1027 fn auto_refresh_with_token(
1028 dir: &tempfile::TempDir,
1029 base_url: &url::Url,
1030 token: Token,
1031 ) -> AutoRefresh<DeviceSessionRefresher> {
1032 let store = ProfileStore::new(dir.path());
1033 store.init_workspace("ZVATKW3VHMFG27DY").unwrap();
1034 let ws_store = store.current_workspace_store().unwrap();
1035 ws_store.save_profile(&token).unwrap();
1036 let refresher = DeviceSessionRefresher::new(
1037 Some(ws_store),
1038 base_url.clone(),
1039 "cli",
1040 "ap-southeast-2.aws",
1041 None,
1042 );
1043 AutoRefresh::with_token(refresher, token)
1044 }
1045
/// Number of concurrent `get_token` callers spawned by each stress test.
const CONCURRENCY: usize = 50;
1047
mod given_fresh_token {
    use super::*;

    /// With a fresh token, every concurrent caller gets it quickly and the
    /// (slow) refresh endpoint is never contacted.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn all_callers_return_immediately() {
        let counting = CountingState::new();
        let state = DelayedRefreshState {
            counting: counting.clone(),
            delay: Duration::from_millis(500),
        };
        let (base_url, stats) = start_axum_server(delayed_refresh_handler, state).await;
        let dir = tempfile::tempdir().unwrap();
        let strategy = Arc::new(auto_refresh_with_token(
            &dir,
            &base_url,
            make_token("fresh-token", 3600, true),
        ));

        let start = Instant::now();
        let mut handles = Vec::with_capacity(CONCURRENCY);
        for _ in 0..CONCURRENCY {
            let s = Arc::clone(&strategy);
            handles.push(tokio::spawn(async move { s.get_token().await.unwrap() }));
        }

        // Await every task, propagating any panic from a caller.
        let results: Vec<_> = {
            let mut results = Vec::with_capacity(handles.len());
            for handle in handles {
                results.push(handle.await.unwrap());
            }
            results
        };
        let elapsed = start.elapsed();

        for token in &results {
            assert_eq!(
                token.as_str(),
                "fresh-token",
                "all callers should receive the fresh token"
            );
        }

        // Well below the endpoint's 500ms delay: nobody waited on a refresh.
        assert!(
            elapsed < Duration::from_millis(200),
            "expected < 200ms for fresh tokens, got {:?}",
            elapsed
        );
        assert_eq!(stats.total(), 0, "no refresh requests should be made");
    }
}
1098
mod given_expiring_but_usable_token {
    use super::*;

    /// With a token that has 30 seconds left, one caller performs the slow
    /// refresh while the others are served promptly; only a single request
    /// reaches the endpoint.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn non_blocking_reads_during_refresh() {
        let counting = CountingState::new();
        let state = DelayedRefreshState {
            counting: counting.clone(),
            delay: Duration::from_millis(500),
        };
        let (base_url, stats) = start_axum_server(delayed_refresh_handler, state).await;
        let dir = tempfile::tempdir().unwrap();
        let strategy = Arc::new(auto_refresh_with_token(
            &dir,
            &base_url,
            make_token("still-usable", 30, true),
        ));

        let start = Instant::now();
        let mut handles = Vec::with_capacity(CONCURRENCY);
        for _ in 0..CONCURRENCY {
            let s = Arc::clone(&strategy);
            handles.push(tokio::spawn(async move {
                // Time each call individually.
                let call_start = Instant::now();
                let token = s.get_token().await.unwrap();
                (token, call_start.elapsed())
            }));
        }

        let results: Vec<_> = {
            let mut results = Vec::with_capacity(handles.len());
            for handle in handles {
                results.push(handle.await.unwrap());
            }
            results
        };
        let elapsed = start.elapsed();

        for (token, _) in &results {
            assert!(
                token.as_str() == "still-usable" || token.as_str() == "refreshed-token",
                "unexpected token: {}",
                token.as_str()
            );
        }

        // All but (at most) the one refreshing caller should finish fast.
        let fast_callers = results
            .iter()
            .filter(|(_, dur)| *dur < Duration::from_millis(100))
            .count();
        assert!(
            fast_callers >= CONCURRENCY - 1,
            "expected at least {} fast callers, got {} (total elapsed: {:?})",
            CONCURRENCY - 1,
            fast_callers,
            elapsed
        );

        assert_eq!(stats.peak(), 1, "peak concurrency to refresh endpoint");
        assert_eq!(stats.total(), 1, "total refresh requests");
    }

    /// Callers that arrive after the token's one-second lifetime has elapsed
    /// all end up with the refreshed token, and only one refresh request is
    /// made in total.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn waiters_receive_token_when_expiry_crosses() {
        // The refresh (1500ms) outlasts the token's one-second lifetime.
        let refresh_delay = Duration::from_millis(1500);
        let counting = CountingState::new();
        let state = DelayedRefreshState {
            counting: counting.clone(),
            delay: refresh_delay,
        };
        let (base_url, stats) = start_axum_server(delayed_refresh_handler, state).await;
        let dir = tempfile::tempdir().unwrap();
        let strategy = Arc::new(auto_refresh_with_token(
            &dir,
            &base_url,
            make_token("expiring-soon", 1, true),
        ));

        // The first call runs to completion and must return the old token.
        let first = strategy.get_token().await.unwrap();
        assert_eq!(
            first.as_str(),
            "expiring-soon",
            "first caller should receive the expiring token"
        );

        // Sleep past the token's one-second lifetime.
        tokio::time::sleep(Duration::from_millis(1100)).await;

        let mut handles = Vec::with_capacity(CONCURRENCY);
        for _ in 0..CONCURRENCY {
            let s = Arc::clone(&strategy);
            handles.push(tokio::spawn(async move { s.get_token().await }));
        }

        let results: Vec<_> = {
            let mut results = Vec::with_capacity(handles.len());
            for handle in handles {
                results.push(handle.await.unwrap());
            }
            results
        };

        for (i, result) in results.iter().enumerate() {
            assert!(
                result.is_ok(),
                "caller {i} got Err({:?}), expected Ok",
                result.as_ref().unwrap_err()
            );
            assert_eq!(
                result.as_ref().unwrap().as_str(),
                "refreshed-token",
                "caller {i} should receive the refreshed token"
            );
        }

        assert_eq!(stats.total(), 1, "only one refresh request should be made");
    }
}
1229
1230 mod given_fully_expired_token {
1231 use super::*;
1232
/// With a fully expired token, every concurrent caller ends up with the
/// refreshed token after a single refresh request, in roughly the time of
/// one refresh.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn all_callers_block_until_refresh() {
    let refresh_delay = Duration::from_millis(200);
    let counting = CountingState::new();
    let state = DelayedRefreshState {
        counting: counting.clone(),
        delay: refresh_delay,
    };
    let (base_url, stats) = start_axum_server(delayed_refresh_handler, state).await;
    let dir = tempfile::tempdir().unwrap();
    let strategy = Arc::new(auto_refresh_with_token(
        &dir,
        &base_url,
        make_token("expired-token", 0, true),
    ));

    let start = Instant::now();
    let mut handles = Vec::with_capacity(CONCURRENCY);
    for _ in 0..CONCURRENCY {
        let s = Arc::clone(&strategy);
        handles.push(tokio::spawn(async move { s.get_token().await.unwrap() }));
    }

    // Await every task, propagating any panic from a caller.
    let results: Vec<_> = {
        let mut results = Vec::with_capacity(handles.len());
        for handle in handles {
            results.push(handle.await.unwrap());
        }
        results
    };
    let elapsed = start.elapsed();

    for token in &results {
        assert_eq!(
            token.as_str(),
            "refreshed-token",
            "all callers should receive refreshed token"
        );
    }

    // One refresh plus 200ms of slack; serialized refreshes would exceed it.
    assert!(
        elapsed < refresh_delay + Duration::from_millis(200),
        "expected < {:?} for blocked callers, got {:?}",
        refresh_delay + Duration::from_millis(200),
        elapsed
    );

    assert_eq!(stats.peak(), 1, "peak concurrency to refresh endpoint");
    assert_eq!(stats.total(), 1, "total refresh requests");
}
1283
1284 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1285 async fn all_callers_receive_expired_on_failure() {
1286 let counting = CountingState::new();
1287 let state = DelayedRefreshState {
1288 counting: counting.clone(),
1289 delay: Duration::from_millis(10),
1290 };
1291 let (base_url, stats) = start_axum_server(delayed_error_handler, state).await;
1292 let dir = tempfile::tempdir().unwrap();
1293 let strategy = Arc::new(auto_refresh_with_token(
1294 &dir,
1295 &base_url,
1296 make_token("expired-token", 0, true),
1297 ));
1298
1299 let mut handles = Vec::with_capacity(CONCURRENCY);
1300 for _ in 0..CONCURRENCY {
1301 let s = Arc::clone(&strategy);
1302 handles.push(tokio::spawn(async move { s.get_token().await }));
1303 }
1304
1305 let results: Vec<_> = {
1306 let mut results = Vec::with_capacity(handles.len());
1307 for handle in handles {
1308 results.push(handle.await.unwrap());
1309 }
1310 results
1311 };
1312
1313 for result in &results {
1314 assert!(result.is_err(), "expected Expired error, got Ok");
1315 let err = result.as_ref().unwrap_err();
1316 assert!(
1317 matches!(err, AutoRefreshError::Expired),
1318 "expected Expired, got: {err:?}"
1319 );
1320 }
1321
1322 let state = strategy.state.lock().await;
1323 assert!(
1324 state.token.as_ref().unwrap().refresh_token().is_some(),
1325 "refresh token should be restored after failed refresh"
1326 );
1327 drop(state);
1328
1329 assert_eq!(stats.peak(), 1, "peak concurrency to refresh endpoint");
1330 assert!(
1331 stats.total() >= 1,
1332 "at least one refresh attempt should be made"
1333 );
1334 }
1335
1336 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1337 async fn retry_succeeds_after_failure() {
1338 let counting1 = CountingState::new();
1340 let state1 = DelayedRefreshState {
1341 counting: counting1.clone(),
1342 delay: Duration::from_millis(50),
1343 };
1344 let (base_url, _) = start_axum_server(delayed_error_handler, state1).await;
1345 let dir = tempfile::tempdir().unwrap();
1346 let strategy = Arc::new(auto_refresh_with_token(
1347 &dir,
1348 &base_url,
1349 make_token("expired-token", 0, true),
1350 ));
1351
1352 let mut handles = Vec::with_capacity(CONCURRENCY);
1353 for _ in 0..CONCURRENCY {
1354 let s = Arc::clone(&strategy);
1355 handles.push(tokio::spawn(async move { s.get_token().await }));
1356 }
1357
1358 let results: Vec<_> = {
1359 let mut results = Vec::with_capacity(handles.len());
1360 for handle in handles {
1361 results.push(handle.await.unwrap());
1362 }
1363 results
1364 };
1365
1366 for result in &results {
1367 assert!(
1368 result.is_err(),
1369 "first wave: expected Expired, got Ok({})",
1370 result.as_ref().unwrap().as_str()
1371 );
1372 }
1373
1374 let counting2 = CountingState::new();
1376 let state2 = DelayedRefreshState {
1377 counting: counting2.clone(),
1378 delay: Duration::from_millis(50),
1379 };
1380 let (base_url2, stats2) = start_axum_server(delayed_refresh_handler, state2).await;
1381
1382 let strategy2 = Arc::new(auto_refresh_with_token(
1383 &dir,
1384 &base_url2,
1385 make_token("expired-token", 0, true),
1386 ));
1387
1388 let mut handles = Vec::with_capacity(CONCURRENCY);
1389 for _ in 0..CONCURRENCY {
1390 let s = Arc::clone(&strategy2);
1391 handles.push(tokio::spawn(async move { s.get_token().await.unwrap() }));
1392 }
1393
1394 let results: Vec<_> = {
1395 let mut results = Vec::with_capacity(handles.len());
1396 for handle in handles {
1397 results.push(handle.await.unwrap());
1398 }
1399 results
1400 };
1401
1402 for token in &results {
1403 assert_eq!(
1404 token.as_str(),
1405 "refreshed-token",
1406 "retry callers should receive refreshed token"
1407 );
1408 }
1409
1410 assert_eq!(stats2.total(), 1, "only one retry refresh should be made");
1411 }
1412 }
1413
1414 mod given_cancelled_refresh {
1415 use super::*;
1416
1417 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1421 async fn blocked_callers_recover_after_cancellation() {
1422 let counting = CountingState::new();
1423 let state = DelayedRefreshState {
1424 counting: counting.clone(),
1425 delay: Duration::from_secs(10), };
1427 let (base_url, _) = start_axum_server(delayed_refresh_handler, state).await;
1428 let dir = tempfile::tempdir().unwrap();
1429 let strategy = Arc::new(auto_refresh_with_token(
1430 &dir,
1431 &base_url,
1432 make_token("expired-token", 0, true),
1433 ));
1434
1435 let s = Arc::clone(&strategy);
1437 let handle = tokio::spawn(async move { s.get_token().await });
1438 tokio::time::sleep(Duration::from_millis(100)).await;
1439
1440 handle.abort();
1442 let _ = handle.await;
1443
1444 let s = Arc::clone(&strategy);
1448 let result = tokio::time::timeout(Duration::from_secs(2), s.get_token()).await;
1449
1450 assert!(
1451 result.is_ok(),
1452 "get_token() should not hang after cancelled blocking refresh"
1453 );
1454 }
1455
1456 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1464 async fn save_phase_cancellation_does_not_strand_in_progress_flag() {
1465 use crate::token_store::TokenStore;
1466
1467 struct SlowSaveStore {
1470 initial: tokio::sync::Mutex<Option<Token>>,
1471 delay: Duration,
1472 }
1473
1474 impl TokenStore for SlowSaveStore {
1475 async fn load(&self) -> Option<Token> {
1476 self.initial.lock().await.take()
1477 }
1478
1479 async fn save(&self, _token: &Token) {
1480 tokio::time::sleep(self.delay).await;
1481 }
1482 }
1483
1484 let counting = CountingState::new();
1486 let state = DelayedRefreshState {
1487 counting: counting.clone(),
1488 delay: Duration::from_millis(10),
1489 };
1490 let (base_url, _) = start_axum_server(delayed_refresh_handler, state).await;
1491 let dir = tempfile::tempdir().unwrap();
1492 let store = ProfileStore::new(dir.path());
1493 store.init_workspace("ZVATKW3VHMFG27DY").unwrap();
1494 let ws_store = store.current_workspace_store().unwrap();
1495 let refresher = DeviceSessionRefresher::new(
1496 Some(ws_store),
1497 base_url,
1498 "cli",
1499 "ap-southeast-2.aws",
1500 None,
1501 );
1502 let slow_store = SlowSaveStore {
1505 initial: tokio::sync::Mutex::new(Some(make_token("expired-token", 0, true))),
1506 delay: Duration::from_secs(10),
1507 };
1508 let strategy = Arc::new(AutoRefresh::with_store(refresher, slow_store));
1509
1510 let s = Arc::clone(&strategy);
1513 let handle = tokio::spawn(async move { s.get_token().await });
1514 tokio::time::sleep(Duration::from_millis(200)).await;
1517 handle.abort();
1518 let _ = handle.await;
1519
1520 let s = Arc::clone(&strategy);
1525 let result = tokio::time::timeout(Duration::from_secs(2), s.get_token()).await;
1526
1527 assert!(
1528 result.is_ok(),
1529 "get_token() should not hang after cancellation in the save/install window"
1530 );
1531 }
1532
1533 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1537 async fn non_blocking_callers_recover_after_cancellation() {
1538 let counting = CountingState::new();
1539 let state = DelayedRefreshState {
1540 counting: counting.clone(),
1541 delay: Duration::from_secs(10), };
1543 let (base_url, _) = start_axum_server(delayed_refresh_handler, state).await;
1544 let dir = tempfile::tempdir().unwrap();
1545 let strategy = Arc::new(auto_refresh_with_token(
1547 &dir,
1548 &base_url,
1549 make_token("still-usable", 30, true),
1550 ));
1551
1552 let s = Arc::clone(&strategy);
1555 let handle = tokio::spawn(async move { s.get_token().await });
1556 tokio::time::sleep(Duration::from_millis(100)).await;
1557
1558 handle.abort();
1560 let _ = handle.await;
1561
1562 let s = Arc::clone(&strategy);
1565 let result = tokio::time::timeout(Duration::from_secs(2), s.get_token()).await;
1566
1567 assert!(
1568 result.is_ok(),
1569 "get_token() should not hang after cancelled non-blocking refresh"
1570 );
1571 let result = result.unwrap();
1572 assert!(
1573 result.is_ok(),
1574 "expected Ok with still-usable token, got: {:?}",
1575 result.unwrap_err()
1576 );
1577 }
1578 }
1579}