Skip to main content

stack_auth/
auto_refresh.rs

1use std::sync::atomic::{AtomicBool, Ordering};
2
3use tokio::sync::{Mutex, MutexGuard, Notify};
4
5use crate::refresher::Refresher;
6use crate::token_store::{NoStore, TokenStore};
7use crate::{ServiceToken, Token};
8
/// Internal errors from [`AutoRefresh::get_token`].
///
/// Strategy wrappers convert these into [`AuthError`](crate::AuthError) for the
/// public API (see the `From<AutoRefreshError>` impl in this file for the
/// exact variant mapping).
#[derive(Debug, thiserror::Error)]
pub(crate) enum AutoRefreshError {
    /// No token is cached and the strategy cannot self-authenticate.
    ///
    /// Produced when the cached token is absent and the refresher yields no
    /// credential for an initial authentication.
    #[error("No token found")]
    NotFound,
    /// The token has expired and refresh failed or is unavailable.
    ///
    /// Also returned when a blocking refresh fails: the underlying refresh
    /// error is logged and replaced by this variant.
    #[error("Token has expired")]
    Expired,
    /// The refresh/auth HTTP call failed.
    ///
    /// Within this file only the initial-authentication path returns this
    /// variant directly; `#[from]` additionally allows `?` conversion from
    /// [`AuthError`](crate::AuthError).
    #[error("Auth error: {0}")]
    Auth(#[from] crate::AuthError),
}
25
26impl From<AutoRefreshError> for crate::AuthError {
27    fn from(err: AutoRefreshError) -> Self {
28        match err {
29            AutoRefreshError::NotFound => crate::AuthError::NotAuthenticated,
30            AutoRefreshError::Expired => crate::AuthError::TokenExpired,
31            AutoRefreshError::Auth(e) => e,
32        }
33    }
34}
35
/// Caches a token in memory and uses a [`Refresher`] to re-authenticate
/// or refresh before expiry, optionally backed by an external [`TokenStore`]
/// for persistence across short-lived strategy instances.
///
/// See the [crate-level documentation](crate#token-refresh) for a full
/// description of the concurrency model and flow diagram.
pub(crate) struct AutoRefresh<R, S = NoStore> {
    /// Produces credentials and performs the actual refresh / auth call.
    refresher: R,
    /// The cached token, guarded by an async mutex so it can be held across
    /// `.await` points on the blocking refresh paths.
    state: Mutex<State>,
    /// External persistence consulted when no token is cached in memory and
    /// written to after every successful refresh.
    store: S,
    /// Set to `true` while a refresh HTTP call is in-flight.
    ///
    /// Stored as an [`AtomicBool`] rather than inside [`State`] so that
    /// [`CancelGuard`] can reset it on future cancellation without acquiring
    /// the mutex.
    refresh_in_progress: AtomicBool,
    /// Wakes callers parked in `wait_for_in_flight_refresh` once a
    /// lock-free (non-blocking) refresh finishes or is cancelled.
    refresh_notify: Notify,
}
54
/// Mutable data protected by [`AutoRefresh`]'s `state` mutex.
struct State {
    /// The currently cached token; `None` until one is pre-loaded, read from
    /// the store, or obtained by initial authentication.
    token: Option<Token>,
}
58
/// Ensures [`AutoRefresh::refresh_in_progress`] is cleared and waiters are
/// notified if the refresh future is cancelled (dropped) before completing.
///
/// On the normal path (success or handled error), the guard is defused before
/// drop so that the regular cleanup code runs instead.
struct CancelGuard<'a> {
    /// The owning [`AutoRefresh`]'s in-progress flag, reset to `false` on an
    /// un-defused drop.
    in_progress: &'a AtomicBool,
    /// The owning [`AutoRefresh`]'s notifier; all current waiters are woken
    /// on an un-defused drop.
    notify: &'a Notify,
    /// `true` once [`CancelGuard::defuse`] has been called, turning `Drop`
    /// into a no-op.
    defused: bool,
}
69
70impl Drop for CancelGuard<'_> {
71    fn drop(&mut self) {
72        if !self.defused {
73            self.in_progress.store(false, Ordering::Release);
74            self.notify.notify_waiters();
75        }
76    }
77}
78
impl CancelGuard<'_> {
    /// Disarm the guard so that dropping it no longer touches the in-progress
    /// flag or notifies waiters. Call this once the refresh path has reached
    /// a point where it performs (or has performed) that cleanup itself.
    fn defuse(&mut self) {
        self.defused = true;
    }
}
84
85impl State {
86    fn service_token(&self) -> Result<ServiceToken, AutoRefreshError> {
87        let token = self.token.as_ref().ok_or(AutoRefreshError::NotFound)?;
88        Ok(ServiceToken::new(token.access_token().clone()))
89    }
90
91    fn require_usable_token(&self) -> Result<ServiceToken, AutoRefreshError> {
92        let token = self.token.as_ref().ok_or(AutoRefreshError::NotFound)?;
93        if token.is_usable() {
94            Ok(ServiceToken::new(token.access_token().clone()))
95        } else {
96            Err(AutoRefreshError::Expired)
97        }
98    }
99}
100
101impl<R> AutoRefresh<R, NoStore> {
102    /// Create a new `AutoRefresh` with a pre-loaded token and no external store.
103    ///
104    /// Use this for refreshers that cannot self-authenticate (e.g. OAuth,
105    /// which needs a refresh token from a prior device code flow).
106    pub(crate) fn with_token(refresher: R, token: Token) -> Self {
107        Self {
108            refresher,
109            state: Mutex::new(State { token: Some(token) }),
110            store: NoStore,
111            refresh_in_progress: AtomicBool::new(false),
112            refresh_notify: Notify::new(),
113        }
114    }
115}
116
117impl<R, S: TokenStore> AutoRefresh<R, S> {
118    /// Create a new `AutoRefresh` backed by `store` and no in-memory token.
119    ///
120    /// On the first `get_token` call the store is consulted before falling
121    /// through to initial authentication via `try_credential(None)`; every
122    /// successful refresh writes the new token back via `store.save()`. Pass
123    /// [`NoStore`] for the no-external-cache case — it's the default and
124    /// elides to a zero-cost no-op.
125    pub(crate) fn with_store(refresher: R, store: S) -> Self {
126        Self {
127            refresher,
128            state: Mutex::new(State { token: None }),
129            store,
130            refresh_in_progress: AtomicBool::new(false),
131            refresh_notify: Notify::new(),
132        }
133    }
134}
135
136impl<R: Refresher, S: TokenStore> AutoRefresh<R, S> {
    /// Retrieve a valid access token, refreshing or re-authenticating as needed.
    ///
    /// Decision order, all evaluated under the state lock:
    ///
    /// 1. No cached token → try the store, then initial authentication.
    /// 2. Cached token not flagged expired → return it.
    /// 3. A refresh is already in flight → reuse the token if usable,
    ///    otherwise wait for that refresh.
    /// 4. Refresher offers no credential → return the token only if usable.
    /// 5. Otherwise refresh: without the lock if the token is still usable,
    ///    holding the lock if it is not.
    ///
    /// # Errors
    ///
    /// * [`AutoRefreshError::NotFound`] — nothing cached and no credential
    ///   available for initial authentication.
    /// * [`AutoRefreshError::Expired`] — the token is unusable and could not
    ///   be refreshed.
    /// * [`AutoRefreshError::Auth`] — the initial authentication call failed.
    pub(crate) async fn get_token(&self) -> Result<ServiceToken, AutoRefreshError> {
        let mut state = self.state.lock().await;

        if state.token.is_none() {
            // Drop the lock for the store read so a slow user-supplied backend
            // (cookie, KV, Redis) doesn't serialise concurrent `get_token`
            // callers. Re-acquire and double-check `state.token.is_none()` in
            // case another caller populated it while we awaited.
            drop(state);
            let loaded = self.store.load().await;
            state = self.state.lock().await;
            if state.token.is_none() {
                state.token = loaded;
            }
        }

        // Still nothing after consulting the store: authenticate from scratch
        // while keeping the lock held.
        if state.token.is_none() {
            return self.initial_auth(&mut state).await;
        }

        // Token is not yet flagged as expired — serve it straight from cache.
        if !state.token.as_ref().is_some_and(|t| t.is_expired()) {
            return state.service_token();
        }

        // Someone else already owns the refresh; don't start a second one.
        if self.refresh_in_progress.load(Ordering::Acquire) {
            return self.wait_for_in_flight_refresh(state).await;
        }

        // The refresher gets mutable access to the cached token when producing
        // a credential (the failure paths hand the credential back through
        // `restore`). With no credential there is nothing to refresh with, so
        // the cached token is returned only if it is still usable.
        let Some(credential) = self.refresher.try_credential(state.token.as_mut()) else {
            return state.require_usable_token();
        };

        // Claim the refresh. There is no `.await` between this store and the
        // construction of the `CancelGuard` in the helpers below, so the flag
        // cannot be stranded by cancellation in between.
        self.refresh_in_progress.store(true, Ordering::Release);

        // Expired-but-usable: return the old token and refresh without the
        // lock. Unusable: hold the lock so other callers wait for the result.
        if state.token.as_ref().is_some_and(|t| t.is_usable()) {
            self.refresh_non_blocking(state, credential).await
        } else {
            self.refresh_blocking(&mut state, credential).await
        }
    }
178
179    /// No cached token — authenticate via `try_credential(None)`.
180    ///
181    /// The lock is held throughout to prevent concurrent initial-auth attempts.
182    async fn initial_auth(&self, state: &mut State) -> Result<ServiceToken, AutoRefreshError> {
183        let Some(credential) = self.refresher.try_credential(None) else {
184            return Err(AutoRefreshError::NotFound);
185        };
186        self.refresh_in_progress.store(true, Ordering::Release);
187        let mut guard = CancelGuard {
188            in_progress: &self.refresh_in_progress,
189            notify: &self.refresh_notify,
190            defused: false,
191        };
192        match self.refresher.refresh(&credential).await {
193            Ok(new_token) => {
194                self.save_refreshed_token(&new_token).await;
195                let token = self.install_refreshed_token(state, new_token);
196                guard.defuse();
197                Ok(token)
198            }
199            Err(err) => {
200                guard.defuse();
201                self.refresh_in_progress.store(false, Ordering::Release);
202                Err(AutoRefreshError::Auth(err))
203            }
204        }
205    }
206
    /// Persist a freshly refreshed token to the per-refresher sink and the
    /// user-supplied `TokenStore`. Awaits the store write, so callers should
    /// drop the state lock before invoking this where possible (the
    /// non-blocking refresh path does; the blocking/initial paths hold the
    /// state lock throughout by design).
    ///
    /// Neither call's outcome is inspected here; this function returns
    /// nothing to its caller.
    async fn save_refreshed_token(&self, new_token: &Token) {
        // Synchronous hand-off to the refresher's own sink first…
        self.refresher.save(new_token);
        // …then the (potentially slow) external store.
        self.store.save(new_token).await;
    }
216
217    /// Install a freshly refreshed token in `state`, clear the in-progress
218    /// flag, and return the corresponding [`ServiceToken`]. Pure in-lock
219    /// work; caller is responsible for having already persisted the token via
220    /// [`save_refreshed_token`](Self::save_refreshed_token).
221    fn install_refreshed_token(&self, state: &mut State, new_token: Token) -> ServiceToken {
222        let service_token = ServiceToken::new(new_token.access_token().clone());
223        state.token = Some(new_token);
224        self.refresh_in_progress.store(false, Ordering::Release);
225        service_token
226    }
227
228    /// Another caller is already refreshing — return the current token if still
229    /// usable, otherwise wait for the in-flight refresh to complete via `Notify`.
230    ///
231    /// Takes `MutexGuard` by value because the lock is dropped before awaiting
232    /// the notification.
233    async fn wait_for_in_flight_refresh(
234        &self,
235        state: MutexGuard<'_, State>,
236    ) -> Result<ServiceToken, AutoRefreshError> {
237        if let Ok(token) = state.service_token() {
238            if state.token.as_ref().is_some_and(|t| t.is_usable()) {
239                return Ok(token);
240            }
241        }
242        // Token crossed real expiry during in-flight refresh. Wait for the
243        // refresh to complete rather than returning Expired.
244        let notified = self.refresh_notify.notified();
245        drop(state);
246        notified.await;
247        // Re-check after wake — refresh may have failed.
248        let state = self.state.lock().await;
249        state.require_usable_token()
250    }
251
252    /// Token is expiring but still usable — drop the lock, refresh in the
253    /// background of this call, and return the old (still-valid) token.
254    ///
255    /// Takes `MutexGuard` by value because the lock is dropped before the HTTP
256    /// request. Notifies waiters after the refresh completes (success or error).
257    ///
258    /// A [`CancelGuard`] ensures that if this future is cancelled at any point
259    /// before the new token is installed — including the post-HTTP save +
260    /// install window — `refresh_in_progress` is cleared and waiters are
261    /// notified, so subsequent callers don't hang in
262    /// [`wait_for_in_flight_refresh`](Self::wait_for_in_flight_refresh).
263    /// The credential is not restored on cancellation (it's already gone from
264    /// `state.token`), so the next caller will get whatever the cached token
265    /// offers — usable, expired, or absent.
266    async fn refresh_non_blocking(
267        &self,
268        state: MutexGuard<'_, State>,
269        credential: R::Credential,
270    ) -> Result<ServiceToken, AutoRefreshError> {
271        let current_service_token = state.service_token()?;
272        drop(state);
273
274        let mut guard = CancelGuard {
275            in_progress: &self.refresh_in_progress,
276            notify: &self.refresh_notify,
277            defused: false,
278        };
279
280        match self.refresher.refresh(&credential).await {
281            Ok(new_token) => {
282                self.save_refreshed_token(&new_token).await;
283                let mut state = self.state.lock().await;
284                let _ = self.install_refreshed_token(&mut state, new_token);
285                guard.defuse();
286            }
287            Err(err) => {
288                tracing::warn!(%err, "token refresh failed (token still usable)");
289                // Defer `defuse()` until after the lock acquire so the
290                // CancelGuard's Drop still fires if cancellation lands on
291                // `state.lock().await`. Without this the in-progress flag
292                // would stay set with no `notify_waiters`, wedging every
293                // subsequent caller exactly like the Ok-path bug fixed
294                // earlier in this file.
295                let mut state = self.state.lock().await;
296                if let Some(token) = state.token.as_mut() {
297                    self.refresher.restore(token, credential);
298                }
299                self.refresh_in_progress.store(false, Ordering::Release);
300                guard.defuse();
301            }
302        }
303
304        self.refresh_notify.notify_waiters();
305        Ok(current_service_token)
306    }
307
308    /// Token is fully expired — refresh while holding the lock so concurrent
309    /// callers block on `lock().await` until the new token is available.
310    ///
311    /// A [`CancelGuard`] ensures that if this future is cancelled at any point
312    /// before the new token is installed — including the post-HTTP save
313    /// window — `refresh_in_progress` is cleared and waiters are notified so
314    /// they don't hang indefinitely. (The credential is lost on cancel —
315    /// see [`CancelGuard`] docs — but subsequent callers will get `Expired`
316    /// rather than blocking forever.)
317    async fn refresh_blocking(
318        &self,
319        state: &mut State,
320        credential: R::Credential,
321    ) -> Result<ServiceToken, AutoRefreshError> {
322        let mut guard = CancelGuard {
323            in_progress: &self.refresh_in_progress,
324            notify: &self.refresh_notify,
325            defused: false,
326        };
327        match self.refresher.refresh(&credential).await {
328            Ok(new_token) => {
329                self.save_refreshed_token(&new_token).await;
330                let token = self.install_refreshed_token(state, new_token);
331                guard.defuse();
332                Ok(token)
333            }
334            Err(err) => {
335                guard.defuse();
336                tracing::warn!(%err, "token refresh failed");
337                if let Some(token) = state.token.as_mut() {
338                    self.refresher.restore(token, credential);
339                }
340                self.refresh_in_progress.store(false, Ordering::Release);
341                Err(AutoRefreshError::Expired)
342            }
343        }
344    }
345}
346
347#[cfg(test)]
348#[allow(clippy::unwrap_used)]
349mod tests {
350    use super::*;
351    use crate::oauth_refresher::OAuthRefresher;
352    use crate::SecretToken;
353    use mocktail::prelude::*;
354    use stack_profile::ProfileStore;
355    use std::sync::Arc;
356    use std::time::{SystemTime, UNIX_EPOCH};
357
358    fn make_token(access: &str, expires_in: u64, refresh: bool) -> Token {
359        let now = SystemTime::now()
360            .duration_since(UNIX_EPOCH)
361            .unwrap()
362            .as_secs();
363
364        Token {
365            access_token: SecretToken::new(access),
366            token_type: "Bearer".to_string(),
367            expires_at: now + expires_in,
368            refresh_token: if refresh {
369                Some(SecretToken::new("test-refresh-token"))
370            } else {
371                None
372            },
373            region: None,
374            client_id: None,
375            device_instance_id: None,
376        }
377    }
378
    /// Test fixture: JSON body of a successful token-endpoint response that
    /// issues `access` as the new access token, valid for 3600 seconds, along
    /// with a fixed replacement refresh token.
    fn refresh_response_json(access: &str) -> serde_json::Value {
        serde_json::json!({
            "access_token": access,
            "token_type": "Bearer",
            "expires_in": 3600,
            "refresh_token": "new-refresh-token"
        })
    }
387
    /// Test fixture: JSON error body with the given `error` code and a
    /// description derived from it (`"<error> occurred"`).
    fn error_json(error: &str) -> serde_json::Value {
        serde_json::json!({
            "error": error,
            "error_description": format!("{error} occurred")
        })
    }
394
    /// Test fixture: create an HTTP mock server preloaded with `mocks`, start
    /// it, and return it. Panics if the server fails to start.
    async fn start_server(mocks: MockSet) -> MockServer {
        let server = MockServer::new_http("auto-refresh-test").with_mocks(mocks);
        server.start().await.unwrap();
        server
    }
400
    /// Test fixture: build an [`AutoRefresh`] pre-loaded with `token` whose
    /// [`OAuthRefresher`] talks to `server` and persists into a workspace
    /// profile store rooted at `dir`.
    ///
    /// The same `token` is also saved as the workspace profile up front, so
    /// the on-disk state initially matches the in-memory cache.
    fn auto_refresh_with_token(
        dir: &tempfile::TempDir,
        server: &MockServer,
        token: Token,
    ) -> AutoRefresh<OAuthRefresher> {
        // Set up a workspace on disk and seed it with the token.
        let store = ProfileStore::new(dir.path());
        store.init_workspace("ZVATKW3VHMFG27DY").unwrap();
        let ws_store = store.current_workspace_store().unwrap();
        ws_store.save_profile(&token).unwrap();
        // Point the refresher at the mock server's base URL.
        let refresher = OAuthRefresher::new(
            Some(ws_store),
            server.url(""),
            "cli",
            "ap-southeast-2.aws",
            None,
        );
        AutoRefresh::with_token(refresher, token)
    }
419
420    mod given_no_cached_token {
421        use super::*;
422
        /// With no in-memory token and `NoStore` as the backing store,
        /// `get_token` on an OAuth-backed `AutoRefresh` must fail with
        /// `NotFound`.
        #[tokio::test]
        async fn returns_not_found_for_oauth() {
            let server = start_server(MockSet::new()).await;
            let store = ProfileStore::new("/tmp/nonexistent");
            let refresher = OAuthRefresher::new(
                Some(store),
                server.url(""),
                "cli",
                "ap-southeast-2.aws",
                None,
            );
            let strategy = AutoRefresh::with_store(refresher, NoStore);

            let err = strategy.get_token().await.unwrap_err();

            assert!(
                matches!(err, AutoRefreshError::NotFound),
                "expected NotFound, got: {err:?}"
            );
        }
443    }
444
445    mod given_fresh_token {
446        use super::*;
447
        /// A token with an hour of lifetime left is returned as-is from the
        /// in-memory cache.
        #[tokio::test]
        async fn returns_cached_token() {
            let dir = tempfile::tempdir().unwrap();
            let server = start_server(MockSet::new()).await;
            let strategy =
                auto_refresh_with_token(&dir, &server, make_token("my-access-token", 3600, false));

            let token = strategy.get_token().await.unwrap();

            assert_eq!(
                token.as_str(),
                "my-access-token",
                "should return the cached access token"
            );
        }
463
        /// Repeated calls are served from memory: removing the on-disk
        /// profile between calls does not affect the returned token.
        #[tokio::test]
        async fn caches_across_calls() {
            let dir = tempfile::tempdir().unwrap();
            let server = start_server(MockSet::new()).await;
            let strategy =
                auto_refresh_with_token(&dir, &server, make_token("my-access-token", 3600, false));

            let token1 = strategy.get_token().await.unwrap();
            assert_eq!(
                token1.as_str(),
                "my-access-token",
                "first call should return the cached token"
            );

            // Delete the file — second call should still return the cached token.
            std::fs::remove_file(
                dir.path()
                    .join("workspaces")
                    .join("ZVATKW3VHMFG27DY")
                    .join("auth.json"),
            )
            .unwrap();

            let token2 = strategy.get_token().await.unwrap();
            assert_eq!(
                token2.as_str(),
                "my-access-token",
                "second call should return the cached token even after file deletion"
            );
        }
494
        /// A fresh token is returned even though it carries a refresh token
        /// and the token endpoint is mocked to answer with a 500 error.
        #[tokio::test]
        async fn does_not_trigger_refresh() {
            // Mock that would fail if hit — proves no refresh request is made.
            let mut mocks = MockSet::new();
            mocks.mock(|when, then| {
                when.post().path("/oauth/token");
                then.internal_server_error()
                    .json(error_json("should_not_be_called"));
            });
            let server = start_server(mocks).await;
            let dir = tempfile::tempdir().unwrap();
            let strategy =
                auto_refresh_with_token(&dir, &server, make_token("fresh-token", 3600, true));

            let token = strategy.get_token().await.unwrap();

            assert_eq!(
                token.as_str(),
                "fresh-token",
                "should return fresh token without triggering refresh"
            );
        }
517    }
518
519    mod given_fully_expired_token {
520        use super::*;
521
522        mod without_refresh_token {
523            use super::*;
524
            /// A token created with zero remaining lifetime and no refresh
            /// token yields `Expired`.
            #[tokio::test]
            async fn returns_expired() {
                let dir = tempfile::tempdir().unwrap();
                let server = start_server(MockSet::new()).await;
                let strategy =
                    auto_refresh_with_token(&dir, &server, make_token("old-token", 0, false));

                let err = strategy.get_token().await.unwrap_err();

                assert!(
                    matches!(err, AutoRefreshError::Expired),
                    "expected Expired, got: {err:?}"
                );
            }
539        }
540
541        mod with_refresh_token {
542            use super::*;
543
            /// A token created with zero remaining lifetime plus a refresh
            /// token triggers a refresh, and the refreshed access token is
            /// what `get_token` returns.
            #[tokio::test]
            async fn refreshes_and_returns_new_token() {
                let mut mocks = MockSet::new();
                mocks.mock(|when, then| {
                    when.post().path("/oauth/token");
                    then.json(refresh_response_json("refreshed-token"));
                });
                let server = start_server(mocks).await;
                let dir = tempfile::tempdir().unwrap();
                let strategy =
                    auto_refresh_with_token(&dir, &server, make_token("old-token", 0, true));

                let token = strategy.get_token().await.unwrap();

                assert_eq!(
                    token.as_str(),
                    "refreshed-token",
                    "should return the refreshed token"
                );
            }
564
            /// After a successful refresh, the new token can be read back
            /// from the workspace profile store on disk.
            #[tokio::test]
            async fn persists_refreshed_token_to_disk() {
                let mut mocks = MockSet::new();
                mocks.mock(|when, then| {
                    when.post().path("/oauth/token");
                    then.json(refresh_response_json("refreshed-token"));
                });
                let server = start_server(mocks).await;
                let dir = tempfile::tempdir().unwrap();
                let strategy =
                    auto_refresh_with_token(&dir, &server, make_token("old-token", 0, true));

                let _ = strategy.get_token().await.unwrap();

                // Verify the refreshed token was saved to the workspace directory.
                let store = ProfileStore::new(dir.path());
                let ws_store = store.current_workspace_store().unwrap();
                let on_disk: Token = ws_store.load_profile().unwrap();
                assert_eq!(
                    on_disk.access_token().as_str(),
                    "refreshed-token",
                    "refreshed token should be persisted to disk"
                );
            }
589
            /// When the token endpoint rejects the refresh (400), the caller
            /// sees `Expired` rather than the underlying auth error.
            #[tokio::test]
            async fn returns_expired_on_refresh_failure() {
                let mut mocks = MockSet::new();
                mocks.mock(|when, then| {
                    when.post().path("/oauth/token");
                    then.bad_request().json(error_json("invalid_grant"));
                });
                let server = start_server(mocks).await;
                let dir = tempfile::tempdir().unwrap();
                let strategy =
                    auto_refresh_with_token(&dir, &server, make_token("old-token", 0, true));

                let err = strategy.get_token().await.unwrap_err();

                assert!(
                    matches!(err, AutoRefreshError::Expired),
                    "expected Expired after failed refresh, got: {err:?}"
                );
            }
609
            /// A failed blocking refresh must leave the refresh token on the
            /// cached token so that a later call can retry — and succeed once
            /// the endpoint starts answering successfully.
            #[tokio::test]
            async fn restores_refresh_token_after_failure() {
                let mut mocks = MockSet::new();
                mocks.mock(|when, then| {
                    when.post().path("/oauth/token");
                    then.bad_request().json(error_json("invalid_grant"));
                });
                let server = start_server(mocks).await;
                let dir = tempfile::tempdir().unwrap();
                let strategy =
                    auto_refresh_with_token(&dir, &server, make_token("old-token", 0, true));

                // First call: refresh fails, returns Expired.
                let err = strategy.get_token().await.unwrap_err();
                assert!(
                    matches!(err, AutoRefreshError::Expired),
                    "expected Expired on first attempt, got: {err:?}"
                );

                // Verify the refresh token was restored so a retry is possible.
                let state = strategy.state.lock().await;
                assert!(
                    state.token.is_some(),
                    "token should still be cached after failed refresh"
                );
                assert!(
                    state.token.as_ref().unwrap().refresh_token().is_some(),
                    "refresh token should be restored for retry"
                );
                // Release the lock before calling `get_token` again, which
                // needs to acquire it.
                drop(state);

                // Replace mock with a success response.
                server.mocks().clear();
                server.mocks().mock(|when, then| {
                    when.post().path("/oauth/token");
                    then.json(refresh_response_json("refreshed-token"));
                });

                // Second call: refresh token is available → retry succeeds.
                let token = strategy.get_token().await.unwrap();
                assert_eq!(
                    token.as_str(),
                    "refreshed-token",
                    "retry should succeed with restored refresh token"
                );
            }
656
            /// After one successful refresh, four further sequential calls
            /// keep returning that refreshed token even though the endpoint
            /// would now hand out a different one.
            #[tokio::test]
            async fn sequential_calls_only_refresh_once() {
                let mut mocks = MockSet::new();
                mocks.mock(|when, then| {
                    when.post().path("/oauth/token");
                    then.json(refresh_response_json("refreshed-once"));
                });
                let server = start_server(mocks).await;
                let dir = tempfile::tempdir().unwrap();
                let strategy =
                    auto_refresh_with_token(&dir, &server, make_token("old-token", 0, true));

                // First call triggers refresh.
                let token = strategy.get_token().await.unwrap();
                assert_eq!(
                    token.as_str(),
                    "refreshed-once",
                    "first call should trigger refresh"
                );

                // Swap mock to track if another refresh is attempted.
                server.mocks().clear();
                server.mocks().mock(|when, then| {
                    when.post().path("/oauth/token");
                    then.json(refresh_response_json("refreshed-twice"));
                });

                // Calls 2-5: the refreshed token is fresh, so no further refresh.
                for _ in 0..4 {
                    let token = strategy.get_token().await.unwrap();
                    assert_eq!(
                        token.as_str(),
                        "refreshed-once",
                        "should return cached refreshed token, not trigger another refresh"
                    );
                }
            }
694
            /// After one successful refresh, a second call still succeeds
            /// with the refreshed token even though the endpoint has been
            /// switched to return an error.
            #[tokio::test]
            async fn prevents_second_refresh_after_success() {
                let mut mocks = MockSet::new();
                mocks.mock(|when, then| {
                    when.post().path("/oauth/token");
                    then.json(refresh_response_json("refreshed-token"));
                });
                let server = start_server(mocks).await;
                let dir = tempfile::tempdir().unwrap();
                let strategy =
                    auto_refresh_with_token(&dir, &server, make_token("old-token", 0, true));

                // First call refreshes successfully.
                let token = strategy.get_token().await.unwrap();
                assert_eq!(
                    token.as_str(),
                    "refreshed-token",
                    "first call should refresh the token"
                );

                // Replace the mock with one that errors.
                server.mocks().clear();
                server.mocks().mock(|when, then| {
                    when.post().path("/oauth/token");
                    then.bad_request().json(error_json("should_not_be_called"));
                });

                // Second call should return the refreshed token without hitting
                // the server again (the new token has a fresh expiry).
                let token = strategy.get_token().await.unwrap();
                assert_eq!(
                    token.as_str(),
                    "refreshed-token",
                    "second call should return cached refreshed token"
                );
            }
731        }
732    }
733
734    mod given_expiring_but_usable_token {
735        use super::*;
736
737        mod when_refresh_fails {
738            use super::*;
739
            /// When the refresh of an expiring-but-usable token fails, the
            /// caller still receives the current access token, and both the
            /// access token and the refresh token remain in the cache.
            #[tokio::test]
            async fn returns_current_token() {
                let mut mocks = MockSet::new();
                mocks.mock(|when, then| {
                    when.post().path("/oauth/token");
                    then.bad_request().json(error_json("server_error"));
                });
                let server = start_server(mocks).await;
                let dir = tempfile::tempdir().unwrap();
                // Token expires in 30s (within the 90s leeway so is_expired() = true),
                // but the access token is still technically usable.
                let strategy =
                    auto_refresh_with_token(&dir, &server, make_token("still-usable", 30, true));

                // The refresh fails, but the access token should still be returned
                // because it's still usable (30s remaining > 0).
                let token = strategy.get_token().await.unwrap();
                assert_eq!(
                    token.as_str(),
                    "still-usable",
                    "should return still-usable token despite failed refresh"
                );

                // Verify the access token and refresh token are still present.
                let state = strategy.state.lock().await;
                assert!(state.token.is_some(), "token should still be cached");
                assert_eq!(
                    state.token.as_ref().unwrap().access_token().as_str(),
                    "still-usable",
                    "access token should be unchanged after failed refresh"
                );
                assert!(
                    state.token.as_ref().unwrap().refresh_token().is_some(),
                    "refresh token should be restored after failed refresh"
                );
            }
776
777            #[tokio::test]
778            async fn restores_refresh_token_for_retry() {
779                let mut mocks = MockSet::new();
780                mocks.mock(|when, then| {
781                    when.post().path("/oauth/token");
782                    then.bad_request().json(error_json("server_error"));
783                });
784                let server = start_server(mocks).await;
785                let dir = tempfile::tempdir().unwrap();
786                // Token expires in 30s — is_expired() = true, is_usable() = true.
787                let strategy =
788                    auto_refresh_with_token(&dir, &server, make_token("still-usable", 30, true));
789
790                // First call: refresh fails, but the still-usable token is returned.
791                let token = strategy.get_token().await.unwrap();
792                assert_eq!(
793                    token.as_str(),
794                    "still-usable",
795                    "first call should return still-usable token"
796                );
797
798                // Replace mock with a success response.
799                server.mocks().clear();
800                server.mocks().mock(|when, then| {
801                    when.post().path("/oauth/token");
802                    then.json(refresh_response_json("refreshed-token"));
803                });
804
805                // Second call: refresh token was restored, so the retry succeeds.
806                let token = strategy.get_token().await.unwrap();
807                assert!(
808                    token.as_str() == "still-usable" || token.as_str() == "refreshed-token",
809                    "expected old or refreshed token, got: {}",
810                    token.as_str()
811                );
812
813                // Verify the cache now holds the refreshed token.
814                let state = strategy.state.lock().await;
815                assert_eq!(
816                    state.token.as_ref().unwrap().access_token().as_str(),
817                    "refreshed-token",
818                    "cache should hold the refreshed token after retry"
819                );
820            }
821        }
822    }
823
824    mod given_concurrent_callers {
825        use super::*;
826
827        #[tokio::test]
828        async fn returns_usable_token_while_refreshing() {
829            let mut mocks = MockSet::new();
830            mocks.mock(|when, then| {
831                when.post().path("/oauth/token");
832                then.json(refresh_response_json("refreshed-token"));
833            });
834            let server = start_server(mocks).await;
835            let dir = tempfile::tempdir().unwrap();
836            let strategy = Arc::new(auto_refresh_with_token(
837                &dir,
838                &server,
839                make_token("still-usable", 30, true),
840            ));
841
842            let s1 = Arc::clone(&strategy);
843            let handle_a = tokio::spawn(async move { s1.get_token().await.unwrap() });
844
845            let s2 = Arc::clone(&strategy);
846            let handle_b = tokio::spawn(async move { s2.get_token().await.unwrap() });
847
848            let (result_a, result_b) = tokio::join!(handle_a, handle_b);
849            let token_a = result_a.unwrap();
850            let token_b = result_b.unwrap();
851
852            assert!(
853                token_a.as_str() == "still-usable" || token_a.as_str() == "refreshed-token",
854                "unexpected token_a: {}",
855                token_a.as_str()
856            );
857            assert!(
858                token_b.as_str() == "still-usable" || token_b.as_str() == "refreshed-token",
859                "unexpected token_b: {}",
860                token_b.as_str()
861            );
862        }
863
864        #[tokio::test]
865        async fn blocks_until_refresh_completes() {
866            let mut mocks = MockSet::new();
867            mocks.mock(|when, then| {
868                when.post().path("/oauth/token");
869                then.json(refresh_response_json("refreshed-token"));
870            });
871            let server = start_server(mocks).await;
872            let dir = tempfile::tempdir().unwrap();
873            let strategy = Arc::new(auto_refresh_with_token(
874                &dir,
875                &server,
876                make_token("expired-token", 0, true),
877            ));
878
879            let s1 = Arc::clone(&strategy);
880            let handle_a = tokio::spawn(async move { s1.get_token().await.unwrap() });
881
882            let s2 = Arc::clone(&strategy);
883            let handle_b = tokio::spawn(async move { s2.get_token().await.unwrap() });
884
885            let (result_a, result_b) = tokio::join!(handle_a, handle_b);
886            let token_a = result_a.unwrap();
887            let token_b = result_b.unwrap();
888
889            assert_eq!(
890                token_a.as_str(),
891                "refreshed-token",
892                "caller a should receive refreshed token"
893            );
894            assert_eq!(
895                token_b.as_str(),
896                "refreshed-token",
897                "caller b should receive refreshed token"
898            );
899        }
900    }
901}
902
903#[cfg(test)]
904#[allow(clippy::unwrap_used)]
905mod stress_tests {
906    use super::*;
907    use crate::oauth_refresher::OAuthRefresher;
908    use crate::SecretToken;
909    use stack_profile::ProfileStore;
910    use std::sync::atomic::{AtomicUsize, Ordering};
911    use std::sync::Arc;
912    use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
913
914    /// Tracks in-flight and peak concurrency for test assertions.
915    #[derive(Clone)]
916    struct CountingState {
917        total: Arc<AtomicUsize>,
918        current: Arc<AtomicUsize>,
919        peak: Arc<AtomicUsize>,
920    }
921
922    impl CountingState {
923        fn new() -> Self {
924            Self {
925                total: Arc::new(AtomicUsize::new(0)),
926                current: Arc::new(AtomicUsize::new(0)),
927                peak: Arc::new(AtomicUsize::new(0)),
928            }
929        }
930
931        fn enter(&self) {
932            self.total.fetch_add(1, Ordering::SeqCst);
933            let prev = self.current.fetch_add(1, Ordering::SeqCst);
934            self.peak.fetch_max(prev + 1, Ordering::SeqCst);
935        }
936
937        fn exit(&self) {
938            self.current.fetch_sub(1, Ordering::SeqCst);
939        }
940
941        fn peak(&self) -> usize {
942            self.peak.load(Ordering::SeqCst)
943        }
944
945        fn total(&self) -> usize {
946            self.total.load(Ordering::SeqCst)
947        }
948    }
949
    /// Shared state for the delayed axum handlers: the request counters they
    /// update and how long each request sleeps before responding.
    #[derive(Clone)]
    struct DelayedRefreshState {
        /// Counters bumped on entry to and exit from each handler call.
        counting: CountingState,
        /// Time each handler sleeps between `enter` and `exit`.
        delay: Duration,
    }
955
956    async fn delayed_refresh_handler(
957        axum::extract::State(state): axum::extract::State<DelayedRefreshState>,
958    ) -> axum::Json<serde_json::Value> {
959        state.counting.enter();
960        tokio::time::sleep(state.delay).await;
961        state.counting.exit();
962        axum::Json(serde_json::json!({
963            "access_token": "refreshed-token",
964            "token_type": "Bearer",
965            "expires_in": 3600,
966            "refresh_token": "new-refresh-token"
967        }))
968    }
969
970    async fn delayed_error_handler(
971        axum::extract::State(state): axum::extract::State<DelayedRefreshState>,
972    ) -> (axum::http::StatusCode, axum::Json<serde_json::Value>) {
973        state.counting.enter();
974        tokio::time::sleep(state.delay).await;
975        state.counting.exit();
976        (
977            axum::http::StatusCode::BAD_REQUEST,
978            axum::Json(serde_json::json!({
979                "error": "invalid_grant",
980                "error_description": "invalid_grant occurred"
981            })),
982        )
983    }
984
985    async fn start_axum_server<H, T>(
986        handler: H,
987        state: DelayedRefreshState,
988    ) -> (url::Url, CountingState)
989    where
990        H: axum::handler::Handler<T, DelayedRefreshState> + Clone + Send + 'static,
991        T: 'static,
992    {
993        let counting = state.counting.clone();
994        let app = axum::Router::new()
995            .route("/oauth/token", axum::routing::post(handler))
996            .with_state(state);
997        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
998        let addr = listener.local_addr().unwrap();
999        tokio::spawn(async move {
1000            axum::serve(listener, app).await.unwrap();
1001        });
1002        let base_url = url::Url::parse(&format!("http://{addr}")).unwrap();
1003        (base_url, counting)
1004    }
1005
1006    fn make_token(access: &str, expires_in: u64, refresh: bool) -> Token {
1007        let now = SystemTime::now()
1008            .duration_since(UNIX_EPOCH)
1009            .unwrap()
1010            .as_secs();
1011
1012        Token {
1013            access_token: SecretToken::new(access),
1014            token_type: "Bearer".to_string(),
1015            expires_at: now + expires_in,
1016            refresh_token: if refresh {
1017                Some(SecretToken::new("test-refresh-token"))
1018            } else {
1019                None
1020            },
1021            region: None,
1022            client_id: None,
1023            device_instance_id: None,
1024        }
1025    }
1026
1027    fn auto_refresh_with_token(
1028        dir: &tempfile::TempDir,
1029        base_url: &url::Url,
1030        token: Token,
1031    ) -> AutoRefresh<OAuthRefresher> {
1032        let store = ProfileStore::new(dir.path());
1033        store.init_workspace("ZVATKW3VHMFG27DY").unwrap();
1034        let ws_store = store.current_workspace_store().unwrap();
1035        ws_store.save_profile(&token).unwrap();
1036        let refresher = OAuthRefresher::new(
1037            Some(ws_store),
1038            base_url.clone(),
1039            "cli",
1040            "ap-southeast-2.aws",
1041            None,
1042        );
1043        AutoRefresh::with_token(refresher, token)
1044    }
1045
    /// Number of concurrent `get_token` callers spawned by each stress test.
    const CONCURRENCY: usize = 50;
1047
1048    mod given_fresh_token {
1049        use super::*;
1050
1051        #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1052        async fn all_callers_return_immediately() {
1053            let counting = CountingState::new();
1054            let state = DelayedRefreshState {
1055                counting: counting.clone(),
1056                delay: Duration::from_millis(500),
1057            };
1058            let (base_url, stats) = start_axum_server(delayed_refresh_handler, state).await;
1059            let dir = tempfile::tempdir().unwrap();
1060            let strategy = Arc::new(auto_refresh_with_token(
1061                &dir,
1062                &base_url,
1063                make_token("fresh-token", 3600, true),
1064            ));
1065
1066            let start = Instant::now();
1067            let mut handles = Vec::with_capacity(CONCURRENCY);
1068            for _ in 0..CONCURRENCY {
1069                let s = Arc::clone(&strategy);
1070                handles.push(tokio::spawn(async move { s.get_token().await.unwrap() }));
1071            }
1072
1073            let results: Vec<_> = {
1074                let mut results = Vec::with_capacity(handles.len());
1075                for handle in handles {
1076                    results.push(handle.await.unwrap());
1077                }
1078                results
1079            };
1080            let elapsed = start.elapsed();
1081
1082            for token in &results {
1083                assert_eq!(
1084                    token.as_str(),
1085                    "fresh-token",
1086                    "all callers should receive the fresh token"
1087                );
1088            }
1089
1090            assert!(
1091                elapsed < Duration::from_millis(200),
1092                "expected < 200ms for fresh tokens, got {:?}",
1093                elapsed
1094            );
1095            assert_eq!(stats.total(), 0, "no refresh requests should be made");
1096        }
1097    }
1098
1099    mod given_expiring_but_usable_token {
1100        use super::*;
1101
1102        #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1103        async fn non_blocking_reads_during_refresh() {
1104            let counting = CountingState::new();
1105            let state = DelayedRefreshState {
1106                counting: counting.clone(),
1107                delay: Duration::from_millis(500),
1108            };
1109            let (base_url, stats) = start_axum_server(delayed_refresh_handler, state).await;
1110            let dir = tempfile::tempdir().unwrap();
1111            let strategy = Arc::new(auto_refresh_with_token(
1112                &dir,
1113                &base_url,
1114                make_token("still-usable", 30, true),
1115            ));
1116
1117            let start = Instant::now();
1118            let mut handles = Vec::with_capacity(CONCURRENCY);
1119            for _ in 0..CONCURRENCY {
1120                let s = Arc::clone(&strategy);
1121                handles.push(tokio::spawn(async move {
1122                    let call_start = Instant::now();
1123                    let token = s.get_token().await.unwrap();
1124                    (token, call_start.elapsed())
1125                }));
1126            }
1127
1128            let results: Vec<_> = {
1129                let mut results = Vec::with_capacity(handles.len());
1130                for handle in handles {
1131                    results.push(handle.await.unwrap());
1132                }
1133                results
1134            };
1135            let elapsed = start.elapsed();
1136
1137            for (token, _) in &results {
1138                assert!(
1139                    token.as_str() == "still-usable" || token.as_str() == "refreshed-token",
1140                    "unexpected token: {}",
1141                    token.as_str()
1142                );
1143            }
1144
1145            let fast_callers = results
1146                .iter()
1147                .filter(|(_, dur)| *dur < Duration::from_millis(100))
1148                .count();
1149            assert!(
1150                fast_callers >= CONCURRENCY - 1,
1151                "expected at least {} fast callers, got {} (total elapsed: {:?})",
1152                CONCURRENCY - 1,
1153                fast_callers,
1154                elapsed
1155            );
1156
1157            assert_eq!(stats.peak(), 1, "peak concurrency to refresh endpoint");
1158            assert_eq!(stats.total(), 1, "total refresh requests");
1159        }
1160
1161        /// Reproduces the race condition where a token crosses real expiry during
1162        /// an in-flight non-blocking refresh. Before the fix, late-arriving callers
1163        /// would see `refresh_in_progress = true` + `!is_usable()` and return
1164        /// `Err(Expired)` instead of waiting for the refresh to complete.
1165        #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1166        async fn waiters_receive_token_when_expiry_crosses() {
1167            // Token with 1s until real expiry (minimum granularity since
1168            // expires_at is in seconds). is_expired() = true (within 90s leeway),
1169            // is_usable() = true (1s remaining). Refresh takes 1.5s so the token
1170            // crosses real expiry mid-refresh.
1171            let refresh_delay = Duration::from_millis(1500);
1172            let counting = CountingState::new();
1173            let state = DelayedRefreshState {
1174                counting: counting.clone(),
1175                delay: refresh_delay,
1176            };
1177            let (base_url, stats) = start_axum_server(delayed_refresh_handler, state).await;
1178            let dir = tempfile::tempdir().unwrap();
1179            let strategy = Arc::new(auto_refresh_with_token(
1180                &dir,
1181                &base_url,
1182                make_token("expiring-soon", 1, true),
1183            ));
1184
1185            // First caller triggers the non-blocking refresh and gets the old token.
1186            let first = strategy.get_token().await.unwrap();
1187            assert_eq!(
1188                first.as_str(),
1189                "expiring-soon",
1190                "first caller should receive the expiring token"
1191            );
1192
1193            // Wait for the token to cross real expiry (but refresh is still in-flight).
1194            tokio::time::sleep(Duration::from_millis(1100)).await;
1195
1196            // Launch 50 concurrent callers. Without the fix, these would all get
1197            // Err(Expired) because refresh_in_progress = true and !is_usable().
1198            let mut handles = Vec::with_capacity(CONCURRENCY);
1199            for _ in 0..CONCURRENCY {
1200                let s = Arc::clone(&strategy);
1201                handles.push(tokio::spawn(async move { s.get_token().await }));
1202            }
1203
1204            let results: Vec<_> = {
1205                let mut results = Vec::with_capacity(handles.len());
1206                for handle in handles {
1207                    results.push(handle.await.unwrap());
1208                }
1209                results
1210            };
1211
1212            // All callers must succeed — none should get Expired.
1213            for (i, result) in results.iter().enumerate() {
1214                assert!(
1215                    result.is_ok(),
1216                    "caller {i} got Err({:?}), expected Ok",
1217                    result.as_ref().unwrap_err()
1218                );
1219                assert_eq!(
1220                    result.as_ref().unwrap().as_str(),
1221                    "refreshed-token",
1222                    "caller {i} should receive the refreshed token"
1223                );
1224            }
1225
1226            assert_eq!(stats.total(), 1, "only one refresh request should be made");
1227        }
1228    }
1229
1230    mod given_fully_expired_token {
1231        use super::*;
1232
1233        #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1234        async fn all_callers_block_until_refresh() {
1235            let refresh_delay = Duration::from_millis(200);
1236            let counting = CountingState::new();
1237            let state = DelayedRefreshState {
1238                counting: counting.clone(),
1239                delay: refresh_delay,
1240            };
1241            let (base_url, stats) = start_axum_server(delayed_refresh_handler, state).await;
1242            let dir = tempfile::tempdir().unwrap();
1243            let strategy = Arc::new(auto_refresh_with_token(
1244                &dir,
1245                &base_url,
1246                make_token("expired-token", 0, true),
1247            ));
1248
1249            let start = Instant::now();
1250            let mut handles = Vec::with_capacity(CONCURRENCY);
1251            for _ in 0..CONCURRENCY {
1252                let s = Arc::clone(&strategy);
1253                handles.push(tokio::spawn(async move { s.get_token().await.unwrap() }));
1254            }
1255
1256            let results: Vec<_> = {
1257                let mut results = Vec::with_capacity(handles.len());
1258                for handle in handles {
1259                    results.push(handle.await.unwrap());
1260                }
1261                results
1262            };
1263            let elapsed = start.elapsed();
1264
1265            for token in &results {
1266                assert_eq!(
1267                    token.as_str(),
1268                    "refreshed-token",
1269                    "all callers should receive refreshed token"
1270                );
1271            }
1272
1273            assert!(
1274                elapsed < refresh_delay + Duration::from_millis(200),
1275                "expected < {:?} for blocked callers, got {:?}",
1276                refresh_delay + Duration::from_millis(200),
1277                elapsed
1278            );
1279
1280            assert_eq!(stats.peak(), 1, "peak concurrency to refresh endpoint");
1281            assert_eq!(stats.total(), 1, "total refresh requests");
1282        }
1283
1284        #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1285        async fn all_callers_receive_expired_on_failure() {
1286            let counting = CountingState::new();
1287            let state = DelayedRefreshState {
1288                counting: counting.clone(),
1289                delay: Duration::from_millis(10),
1290            };
1291            let (base_url, stats) = start_axum_server(delayed_error_handler, state).await;
1292            let dir = tempfile::tempdir().unwrap();
1293            let strategy = Arc::new(auto_refresh_with_token(
1294                &dir,
1295                &base_url,
1296                make_token("expired-token", 0, true),
1297            ));
1298
1299            let mut handles = Vec::with_capacity(CONCURRENCY);
1300            for _ in 0..CONCURRENCY {
1301                let s = Arc::clone(&strategy);
1302                handles.push(tokio::spawn(async move { s.get_token().await }));
1303            }
1304
1305            let results: Vec<_> = {
1306                let mut results = Vec::with_capacity(handles.len());
1307                for handle in handles {
1308                    results.push(handle.await.unwrap());
1309                }
1310                results
1311            };
1312
1313            for result in &results {
1314                assert!(result.is_err(), "expected Expired error, got Ok");
1315                let err = result.as_ref().unwrap_err();
1316                assert!(
1317                    matches!(err, AutoRefreshError::Expired),
1318                    "expected Expired, got: {err:?}"
1319                );
1320            }
1321
1322            let state = strategy.state.lock().await;
1323            assert!(
1324                state.token.as_ref().unwrap().refresh_token().is_some(),
1325                "refresh token should be restored after failed refresh"
1326            );
1327            drop(state);
1328
1329            assert_eq!(stats.peak(), 1, "peak concurrency to refresh endpoint");
1330            assert!(
1331                stats.total() >= 1,
1332                "at least one refresh attempt should be made"
1333            );
1334        }
1335
1336        #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1337        async fn retry_succeeds_after_failure() {
1338            // Phase 1: Server returns errors.
1339            let counting1 = CountingState::new();
1340            let state1 = DelayedRefreshState {
1341                counting: counting1.clone(),
1342                delay: Duration::from_millis(50),
1343            };
1344            let (base_url, _) = start_axum_server(delayed_error_handler, state1).await;
1345            let dir = tempfile::tempdir().unwrap();
1346            let strategy = Arc::new(auto_refresh_with_token(
1347                &dir,
1348                &base_url,
1349                make_token("expired-token", 0, true),
1350            ));
1351
1352            let mut handles = Vec::with_capacity(CONCURRENCY);
1353            for _ in 0..CONCURRENCY {
1354                let s = Arc::clone(&strategy);
1355                handles.push(tokio::spawn(async move { s.get_token().await }));
1356            }
1357
1358            let results: Vec<_> = {
1359                let mut results = Vec::with_capacity(handles.len());
1360                for handle in handles {
1361                    results.push(handle.await.unwrap());
1362                }
1363                results
1364            };
1365
1366            for result in &results {
1367                assert!(
1368                    result.is_err(),
1369                    "first wave: expected Expired, got Ok({})",
1370                    result.as_ref().unwrap().as_str()
1371                );
1372            }
1373
1374            // Phase 2: New server that returns success.
1375            let counting2 = CountingState::new();
1376            let state2 = DelayedRefreshState {
1377                counting: counting2.clone(),
1378                delay: Duration::from_millis(50),
1379            };
1380            let (base_url2, stats2) = start_axum_server(delayed_refresh_handler, state2).await;
1381
1382            let strategy2 = Arc::new(auto_refresh_with_token(
1383                &dir,
1384                &base_url2,
1385                make_token("expired-token", 0, true),
1386            ));
1387
1388            let mut handles = Vec::with_capacity(CONCURRENCY);
1389            for _ in 0..CONCURRENCY {
1390                let s = Arc::clone(&strategy2);
1391                handles.push(tokio::spawn(async move { s.get_token().await.unwrap() }));
1392            }
1393
1394            let results: Vec<_> = {
1395                let mut results = Vec::with_capacity(handles.len());
1396                for handle in handles {
1397                    results.push(handle.await.unwrap());
1398                }
1399                results
1400            };
1401
1402            for token in &results {
1403                assert_eq!(
1404                    token.as_str(),
1405                    "refreshed-token",
1406                    "retry callers should receive refreshed token"
1407                );
1408            }
1409
1410            assert_eq!(stats2.total(), 1, "only one retry refresh should be made");
1411        }
1412    }
1413
1414    mod given_cancelled_refresh {
1415        use super::*;
1416
1417        /// If a blocking refresh (fully expired token) is cancelled mid-flight,
1418        /// the `CancelGuard` must reset `refresh_in_progress` and notify waiters
1419        /// so the next caller doesn't hang in `wait_for_in_flight_refresh`.
1420        #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1421        async fn blocked_callers_recover_after_cancellation() {
1422            let counting = CountingState::new();
1423            let state = DelayedRefreshState {
1424                counting: counting.clone(),
1425                delay: Duration::from_secs(10), // Very slow — will be cancelled
1426            };
1427            let (base_url, _) = start_axum_server(delayed_refresh_handler, state).await;
1428            let dir = tempfile::tempdir().unwrap();
1429            let strategy = Arc::new(auto_refresh_with_token(
1430                &dir,
1431                &base_url,
1432                make_token("expired-token", 0, true),
1433            ));
1434
1435            // Spawn get_token and let the blocking refresh start.
1436            let s = Arc::clone(&strategy);
1437            let handle = tokio::spawn(async move { s.get_token().await });
1438            tokio::time::sleep(Duration::from_millis(100)).await;
1439
1440            // Cancel the refresh mid-flight.
1441            handle.abort();
1442            let _ = handle.await;
1443
1444            // The next caller must not hang. The credential is lost (refresh
1445            // token was taken before the HTTP call), so the result is Expired,
1446            // but the important thing is that it completes promptly.
1447            let s = Arc::clone(&strategy);
1448            let result = tokio::time::timeout(Duration::from_secs(2), s.get_token()).await;
1449
1450            assert!(
1451                result.is_ok(),
1452                "get_token() should not hang after cancelled blocking refresh"
1453            );
1454        }
1455
1456        /// Regression test: cancellation in the window *after* the upstream
1457        /// HTTP refresh succeeds but *before* the new token is installed must
1458        /// still clear `refresh_in_progress` and notify waiters. The previous
1459        /// implementation defused the [`CancelGuard`] before
1460        /// `save_refreshed_token`, so a drop during the (async) store-save or
1461        /// the subsequent state-lock acquire would strand the flag — wedging
1462        /// any caller that later hit `wait_for_in_flight_refresh`.
1463        #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1464        async fn save_phase_cancellation_does_not_strand_in_progress_flag() {
1465            use crate::token_store::TokenStore;
1466
1467            /// Store that returns a single pre-loaded token from `load()` and
1468            /// delays inside `save()` long enough for a test to cancel.
1469            struct SlowSaveStore {
1470                initial: tokio::sync::Mutex<Option<Token>>,
1471                delay: Duration,
1472            }
1473
1474            impl TokenStore for SlowSaveStore {
1475                async fn load(&self) -> Option<Token> {
1476                    self.initial.lock().await.take()
1477                }
1478
1479                async fn save(&self, _token: &Token) {
1480                    tokio::time::sleep(self.delay).await;
1481                }
1482            }
1483
1484            // Fast upstream HTTP — refresh succeeds in <50ms.
1485            let counting = CountingState::new();
1486            let state = DelayedRefreshState {
1487                counting: counting.clone(),
1488                delay: Duration::from_millis(10),
1489            };
1490            let (base_url, _) = start_axum_server(delayed_refresh_handler, state).await;
1491            let dir = tempfile::tempdir().unwrap();
1492            let store = ProfileStore::new(dir.path());
1493            store.init_workspace("ZVATKW3VHMFG27DY").unwrap();
1494            let ws_store = store.current_workspace_store().unwrap();
1495            let refresher =
1496                OAuthRefresher::new(Some(ws_store), base_url, "cli", "ap-southeast-2.aws", None);
1497            // Slow async save — cancellation reliably lands here, in the
1498            // post-HTTP / pre-install window.
1499            let slow_store = SlowSaveStore {
1500                initial: tokio::sync::Mutex::new(Some(make_token("expired-token", 0, true))),
1501                delay: Duration::from_secs(10),
1502            };
1503            let strategy = Arc::new(AutoRefresh::with_store(refresher, slow_store));
1504
1505            // Trigger refresh; the task will complete the HTTP exchange and
1506            // then block inside store.save (the slow async path).
1507            let s = Arc::clone(&strategy);
1508            let handle = tokio::spawn(async move { s.get_token().await });
1509            // 200ms is comfortably past the 10ms HTTP delay but well inside
1510            // the 10s save delay — so abort() lands during save_refreshed_token.
1511            tokio::time::sleep(Duration::from_millis(200)).await;
1512            handle.abort();
1513            let _ = handle.await;
1514
1515            // The CancelGuard must have cleared refresh_in_progress on drop.
1516            // If the old (pre-fix) code regresses, the flag stays true and a
1517            // subsequent caller wedges on wait_for_in_flight_refresh waiting
1518            // for a notify that will never come — the timeout below catches it.
1519            let s = Arc::clone(&strategy);
1520            let result = tokio::time::timeout(Duration::from_secs(2), s.get_token()).await;
1521
1522            assert!(
1523                result.is_ok(),
1524                "get_token() should not hang after cancellation in the save/install window"
1525            );
1526        }
1527
1528        /// If a non-blocking refresh (expiring-but-usable token) is cancelled
1529        /// mid-flight, the `CancelGuard` must reset `refresh_in_progress` and
1530        /// notify waiters so they don't hang once the token crosses real expiry.
1531        #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1532        async fn non_blocking_callers_recover_after_cancellation() {
1533            let counting = CountingState::new();
1534            let state = DelayedRefreshState {
1535                counting: counting.clone(),
1536                delay: Duration::from_secs(10), // Very slow — will be cancelled
1537            };
1538            let (base_url, _) = start_axum_server(delayed_refresh_handler, state).await;
1539            let dir = tempfile::tempdir().unwrap();
1540            // Token expires in 30s — is_expired() = true, is_usable() = true.
1541            let strategy = Arc::new(auto_refresh_with_token(
1542                &dir,
1543                &base_url,
1544                make_token("still-usable", 30, true),
1545            ));
1546
1547            // Spawn get_token — triggers non-blocking refresh, drops lock, then
1548            // blocks on the slow HTTP call.
1549            let s = Arc::clone(&strategy);
1550            let handle = tokio::spawn(async move { s.get_token().await });
1551            tokio::time::sleep(Duration::from_millis(100)).await;
1552
1553            // Cancel the refresh mid-flight.
1554            handle.abort();
1555            let _ = handle.await;
1556
1557            // The next caller must not hang. The token is still usable so it
1558            // should be returned even though the refresh was cancelled.
1559            let s = Arc::clone(&strategy);
1560            let result = tokio::time::timeout(Duration::from_secs(2), s.get_token()).await;
1561
1562            assert!(
1563                result.is_ok(),
1564                "get_token() should not hang after cancelled non-blocking refresh"
1565            );
1566            let result = result.unwrap();
1567            assert!(
1568                result.is_ok(),
1569                "expected Ok with still-usable token, got: {:?}",
1570                result.unwrap_err()
1571            );
1572        }
1573    }
1574}