Skip to main content

stack_auth/
auto_refresh.rs

1use std::sync::atomic::{AtomicBool, Ordering};
2
3use tokio::sync::{Mutex, MutexGuard, Notify};
4
5use crate::refresher::Refresher;
6use crate::token_store::{NoStore, TokenStore};
7use crate::{ServiceToken, Token};
8
9/// Internal errors from [`AutoRefresh::get_token`].
10///
11/// Strategy wrappers convert these into [`AuthError`](crate::AuthError) for the
12/// public API.
13#[derive(Debug, thiserror::Error)]
14pub(crate) enum AutoRefreshError {
15    /// No token is cached and the strategy cannot self-authenticate.
16    #[error("No token found")]
17    NotFound,
18    /// The token has expired and refresh failed or is unavailable.
19    #[error("Token has expired")]
20    Expired,
21    /// The refresh/auth HTTP call failed.
22    #[error("Auth error: {0}")]
23    Auth(#[from] crate::AuthError),
24}
25
26impl From<AutoRefreshError> for crate::AuthError {
27    fn from(err: AutoRefreshError) -> Self {
28        match err {
29            AutoRefreshError::NotFound => crate::AuthError::NotAuthenticated,
30            AutoRefreshError::Expired => crate::AuthError::TokenExpired,
31            AutoRefreshError::Auth(e) => e,
32        }
33    }
34}
35
36/// Caches a token in memory and uses a [`Refresher`] to re-authenticate
37/// or refresh before expiry, optionally backed by an external [`TokenStore`]
38/// for persistence across short-lived strategy instances.
39///
40/// See the [crate-level documentation](crate#token-refresh) for a full
41/// description of the concurrency model and flow diagram.
42pub(crate) struct AutoRefresh<R, S = NoStore> {
43    refresher: R,
44    state: Mutex<State>,
45    store: S,
46    /// Set to `true` while a refresh HTTP call is in-flight.
47    ///
48    /// Stored as an [`AtomicBool`] rather than inside [`State`] so that
49    /// [`CancelGuard`] can reset it on future cancellation without acquiring
50    /// the mutex.
51    refresh_in_progress: AtomicBool,
52    refresh_notify: Notify,
53}
54
55struct State {
56    token: Option<Token>,
57}
58
59/// Ensures [`AutoRefresh::refresh_in_progress`] is cleared and waiters are
60/// notified if the refresh future is cancelled (dropped) before completing.
61///
62/// On the normal path (success or handled error), the guard is defused before
63/// drop so that the regular cleanup code runs instead.
64struct CancelGuard<'a> {
65    in_progress: &'a AtomicBool,
66    notify: &'a Notify,
67    defused: bool,
68}
69
70impl Drop for CancelGuard<'_> {
71    fn drop(&mut self) {
72        if !self.defused {
73            self.in_progress.store(false, Ordering::Release);
74            self.notify.notify_waiters();
75        }
76    }
77}
78
79impl CancelGuard<'_> {
80    fn defuse(&mut self) {
81        self.defused = true;
82    }
83}
84
85impl State {
86    fn service_token(&self) -> Result<ServiceToken, AutoRefreshError> {
87        let token = self.token.as_ref().ok_or(AutoRefreshError::NotFound)?;
88        Ok(ServiceToken::new(token.access_token().clone()))
89    }
90
91    fn require_usable_token(&self) -> Result<ServiceToken, AutoRefreshError> {
92        let token = self.token.as_ref().ok_or(AutoRefreshError::NotFound)?;
93        if token.is_usable() {
94            Ok(ServiceToken::new(token.access_token().clone()))
95        } else {
96            Err(AutoRefreshError::Expired)
97        }
98    }
99}
100
101impl<R> AutoRefresh<R, NoStore> {
102    /// Create a new `AutoRefresh` with a pre-loaded token and no external store.
103    ///
104    /// Use this for refreshers that cannot self-authenticate (e.g. OAuth,
105    /// which needs a refresh token from a prior device code flow).
106    pub(crate) fn with_token(refresher: R, token: Token) -> Self {
107        Self {
108            refresher,
109            state: Mutex::new(State { token: Some(token) }),
110            store: NoStore,
111            refresh_in_progress: AtomicBool::new(false),
112            refresh_notify: Notify::new(),
113        }
114    }
115}
116
117impl<R, S: TokenStore> AutoRefresh<R, S> {
118    /// Create a new `AutoRefresh` backed by `store` and no in-memory token.
119    ///
120    /// On the first `get_token` call the store is consulted before falling
121    /// through to initial authentication via `try_credential(None)`; every
122    /// successful refresh writes the new token back via `store.save()`. Pass
123    /// [`NoStore`] for the no-external-cache case — it's the default and
124    /// elides to a zero-cost no-op.
125    pub(crate) fn with_store(refresher: R, store: S) -> Self {
126        Self {
127            refresher,
128            state: Mutex::new(State { token: None }),
129            store,
130            refresh_in_progress: AtomicBool::new(false),
131            refresh_notify: Notify::new(),
132        }
133    }
134}
135
136impl<R: Refresher, S: TokenStore> AutoRefresh<R, S> {
137    /// Retrieve a valid access token, refreshing or re-authenticating as needed.
138    pub(crate) async fn get_token(&self) -> Result<ServiceToken, AutoRefreshError> {
139        let mut state = self.state.lock().await;
140
141        if state.token.is_none() {
142            // Drop the lock for the store read so a slow user-supplied backend
143            // (cookie, KV, Redis) doesn't serialise concurrent `get_token`
144            // callers. Re-acquire and double-check `state.token.is_none()` in
145            // case another caller populated it while we awaited.
146            drop(state);
147            let loaded = self.store.load().await;
148            state = self.state.lock().await;
149            if state.token.is_none() {
150                state.token = loaded;
151            }
152        }
153
154        if state.token.is_none() {
155            return self.initial_auth(&mut state).await;
156        }
157
158        if !state.token.as_ref().is_some_and(|t| t.is_expired()) {
159            return state.service_token();
160        }
161
162        if self.refresh_in_progress.load(Ordering::Acquire) {
163            return self.wait_for_in_flight_refresh(state).await;
164        }
165
166        let Some(credential) = self.refresher.try_credential(state.token.as_mut()) else {
167            return state.require_usable_token();
168        };
169
170        self.refresh_in_progress.store(true, Ordering::Release);
171
172        if state.token.as_ref().is_some_and(|t| t.is_usable()) {
173            self.refresh_non_blocking(state, credential).await
174        } else {
175            self.refresh_blocking(&mut state, credential).await
176        }
177    }
178
179    /// No cached token — authenticate via `try_credential(None)`.
180    ///
181    /// The lock is held throughout to prevent concurrent initial-auth attempts.
182    async fn initial_auth(&self, state: &mut State) -> Result<ServiceToken, AutoRefreshError> {
183        let Some(credential) = self.refresher.try_credential(None) else {
184            return Err(AutoRefreshError::NotFound);
185        };
186        self.refresh_in_progress.store(true, Ordering::Release);
187        let mut guard = CancelGuard {
188            in_progress: &self.refresh_in_progress,
189            notify: &self.refresh_notify,
190            defused: false,
191        };
192        match self.refresher.refresh(&credential).await {
193            Ok(new_token) => {
194                guard.defuse();
195                self.save_refreshed_token(&new_token).await;
196                Ok(self.install_refreshed_token(state, new_token))
197            }
198            Err(err) => {
199                guard.defuse();
200                self.refresh_in_progress.store(false, Ordering::Release);
201                Err(AutoRefreshError::Auth(err))
202            }
203        }
204    }
205
206    /// Persist a freshly refreshed token to the per-refresher sink and the
207    /// user-supplied `TokenStore`. Awaits the store write, so callers should
208    /// drop the state lock before invoking this where possible (the
209    /// non-blocking refresh path does; the blocking/initial paths hold the
210    /// state lock throughout by design).
211    async fn save_refreshed_token(&self, new_token: &Token) {
212        self.refresher.save(new_token);
213        self.store.save(new_token).await;
214    }
215
216    /// Install a freshly refreshed token in `state`, clear the in-progress
217    /// flag, and return the corresponding [`ServiceToken`]. Pure in-lock
218    /// work; caller is responsible for having already persisted the token via
219    /// [`save_refreshed_token`](Self::save_refreshed_token).
220    fn install_refreshed_token(&self, state: &mut State, new_token: Token) -> ServiceToken {
221        let service_token = ServiceToken::new(new_token.access_token().clone());
222        state.token = Some(new_token);
223        self.refresh_in_progress.store(false, Ordering::Release);
224        service_token
225    }
226
227    /// Another caller is already refreshing — return the current token if still
228    /// usable, otherwise wait for the in-flight refresh to complete via `Notify`.
229    ///
230    /// Takes `MutexGuard` by value because the lock is dropped before awaiting
231    /// the notification.
232    async fn wait_for_in_flight_refresh(
233        &self,
234        state: MutexGuard<'_, State>,
235    ) -> Result<ServiceToken, AutoRefreshError> {
236        if let Ok(token) = state.service_token() {
237            if state.token.as_ref().is_some_and(|t| t.is_usable()) {
238                return Ok(token);
239            }
240        }
241        // Token crossed real expiry during in-flight refresh. Wait for the
242        // refresh to complete rather than returning Expired.
243        let notified = self.refresh_notify.notified();
244        drop(state);
245        notified.await;
246        // Re-check after wake — refresh may have failed.
247        let state = self.state.lock().await;
248        state.require_usable_token()
249    }
250
251    /// Token is expiring but still usable — drop the lock, refresh in the
252    /// background of this call, and return the old (still-valid) token.
253    ///
254    /// Takes `MutexGuard` by value because the lock is dropped before the HTTP
255    /// request. Notifies waiters after the refresh completes (success or error).
256    ///
257    /// A [`CancelGuard`] ensures that if this future is cancelled during the
258    /// HTTP request, `refresh_in_progress` is cleared, the credential is
259    /// restored (best-effort via `try_lock`), and waiters are notified.
260    async fn refresh_non_blocking(
261        &self,
262        state: MutexGuard<'_, State>,
263        credential: R::Credential,
264    ) -> Result<ServiceToken, AutoRefreshError> {
265        let current_service_token = state.service_token()?;
266        drop(state);
267
268        let mut guard = CancelGuard {
269            in_progress: &self.refresh_in_progress,
270            notify: &self.refresh_notify,
271            defused: false,
272        };
273
274        match self.refresher.refresh(&credential).await {
275            Ok(new_token) => {
276                guard.defuse();
277                self.save_refreshed_token(&new_token).await;
278                let mut state = self.state.lock().await;
279                let _ = self.install_refreshed_token(&mut state, new_token);
280            }
281            Err(err) => {
282                guard.defuse();
283                tracing::warn!(%err, "token refresh failed (token still usable)");
284                let mut state = self.state.lock().await;
285                if let Some(token) = state.token.as_mut() {
286                    self.refresher.restore(token, credential);
287                }
288                self.refresh_in_progress.store(false, Ordering::Release);
289            }
290        }
291
292        self.refresh_notify.notify_waiters();
293        Ok(current_service_token)
294    }
295
296    /// Token is fully expired — refresh while holding the lock so concurrent
297    /// callers block on `lock().await` until the new token is available.
298    ///
299    /// A [`CancelGuard`] ensures that if this future is cancelled during the
300    /// HTTP request, `refresh_in_progress` is cleared and waiters are notified
301    /// so they don't hang indefinitely. (The credential is lost on cancel —
302    /// see [`CancelGuard`] docs — but subsequent callers will get `Expired`
303    /// rather than blocking forever.)
304    async fn refresh_blocking(
305        &self,
306        state: &mut State,
307        credential: R::Credential,
308    ) -> Result<ServiceToken, AutoRefreshError> {
309        let mut guard = CancelGuard {
310            in_progress: &self.refresh_in_progress,
311            notify: &self.refresh_notify,
312            defused: false,
313        };
314        match self.refresher.refresh(&credential).await {
315            Ok(new_token) => {
316                guard.defuse();
317                self.save_refreshed_token(&new_token).await;
318                Ok(self.install_refreshed_token(state, new_token))
319            }
320            Err(err) => {
321                guard.defuse();
322                tracing::warn!(%err, "token refresh failed");
323                if let Some(token) = state.token.as_mut() {
324                    self.refresher.restore(token, credential);
325                }
326                self.refresh_in_progress.store(false, Ordering::Release);
327                Err(AutoRefreshError::Expired)
328            }
329        }
330    }
331}
332
333#[cfg(test)]
334#[allow(clippy::unwrap_used)]
335mod tests {
336    use super::*;
337    use crate::oauth_refresher::OAuthRefresher;
338    use crate::SecretToken;
339    use mocktail::prelude::*;
340    use stack_profile::ProfileStore;
341    use std::sync::Arc;
342    use std::time::{SystemTime, UNIX_EPOCH};
343
344    fn make_token(access: &str, expires_in: u64, refresh: bool) -> Token {
345        let now = SystemTime::now()
346            .duration_since(UNIX_EPOCH)
347            .unwrap()
348            .as_secs();
349
350        Token {
351            access_token: SecretToken::new(access),
352            token_type: "Bearer".to_string(),
353            expires_at: now + expires_in,
354            refresh_token: if refresh {
355                Some(SecretToken::new("test-refresh-token"))
356            } else {
357                None
358            },
359            region: None,
360            client_id: None,
361            device_instance_id: None,
362        }
363    }
364
365    fn refresh_response_json(access: &str) -> serde_json::Value {
366        serde_json::json!({
367            "access_token": access,
368            "token_type": "Bearer",
369            "expires_in": 3600,
370            "refresh_token": "new-refresh-token"
371        })
372    }
373
374    fn error_json(error: &str) -> serde_json::Value {
375        serde_json::json!({
376            "error": error,
377            "error_description": format!("{error} occurred")
378        })
379    }
380
381    async fn start_server(mocks: MockSet) -> MockServer {
382        let server = MockServer::new_http("auto-refresh-test").with_mocks(mocks);
383        server.start().await.unwrap();
384        server
385    }
386
387    fn auto_refresh_with_token(
388        dir: &tempfile::TempDir,
389        server: &MockServer,
390        token: Token,
391    ) -> AutoRefresh<OAuthRefresher> {
392        let store = ProfileStore::new(dir.path());
393        store.init_workspace("ZVATKW3VHMFG27DY").unwrap();
394        let ws_store = store.current_workspace_store().unwrap();
395        ws_store.save_profile(&token).unwrap();
396        let refresher = OAuthRefresher::new(
397            Some(ws_store),
398            server.url(""),
399            "cli",
400            "ap-southeast-2.aws",
401            None,
402        );
403        AutoRefresh::with_token(refresher, token)
404    }
405
406    mod given_no_cached_token {
407        use super::*;
408
409        #[tokio::test]
410        async fn returns_not_found_for_oauth() {
411            let server = start_server(MockSet::new()).await;
412            let store = ProfileStore::new("/tmp/nonexistent");
413            let refresher = OAuthRefresher::new(
414                Some(store),
415                server.url(""),
416                "cli",
417                "ap-southeast-2.aws",
418                None,
419            );
420            let strategy = AutoRefresh::with_store(refresher, NoStore);
421
422            let err = strategy.get_token().await.unwrap_err();
423
424            assert!(
425                matches!(err, AutoRefreshError::NotFound),
426                "expected NotFound, got: {err:?}"
427            );
428        }
429    }
430
431    mod given_fresh_token {
432        use super::*;
433
434        #[tokio::test]
435        async fn returns_cached_token() {
436            let dir = tempfile::tempdir().unwrap();
437            let server = start_server(MockSet::new()).await;
438            let strategy =
439                auto_refresh_with_token(&dir, &server, make_token("my-access-token", 3600, false));
440
441            let token = strategy.get_token().await.unwrap();
442
443            assert_eq!(
444                token.as_str(),
445                "my-access-token",
446                "should return the cached access token"
447            );
448        }
449
450        #[tokio::test]
451        async fn caches_across_calls() {
452            let dir = tempfile::tempdir().unwrap();
453            let server = start_server(MockSet::new()).await;
454            let strategy =
455                auto_refresh_with_token(&dir, &server, make_token("my-access-token", 3600, false));
456
457            let token1 = strategy.get_token().await.unwrap();
458            assert_eq!(
459                token1.as_str(),
460                "my-access-token",
461                "first call should return the cached token"
462            );
463
464            // Delete the file — second call should still return the cached token.
465            std::fs::remove_file(
466                dir.path()
467                    .join("workspaces")
468                    .join("ZVATKW3VHMFG27DY")
469                    .join("auth.json"),
470            )
471            .unwrap();
472
473            let token2 = strategy.get_token().await.unwrap();
474            assert_eq!(
475                token2.as_str(),
476                "my-access-token",
477                "second call should return the cached token even after file deletion"
478            );
479        }
480
481        #[tokio::test]
482        async fn does_not_trigger_refresh() {
483            // Mock that would fail if hit — proves no refresh request is made.
484            let mut mocks = MockSet::new();
485            mocks.mock(|when, then| {
486                when.post().path("/oauth/token");
487                then.internal_server_error()
488                    .json(error_json("should_not_be_called"));
489            });
490            let server = start_server(mocks).await;
491            let dir = tempfile::tempdir().unwrap();
492            let strategy =
493                auto_refresh_with_token(&dir, &server, make_token("fresh-token", 3600, true));
494
495            let token = strategy.get_token().await.unwrap();
496
497            assert_eq!(
498                token.as_str(),
499                "fresh-token",
500                "should return fresh token without triggering refresh"
501            );
502        }
503    }
504
505    mod given_fully_expired_token {
506        use super::*;
507
508        mod without_refresh_token {
509            use super::*;
510
511            #[tokio::test]
512            async fn returns_expired() {
513                let dir = tempfile::tempdir().unwrap();
514                let server = start_server(MockSet::new()).await;
515                let strategy =
516                    auto_refresh_with_token(&dir, &server, make_token("old-token", 0, false));
517
518                let err = strategy.get_token().await.unwrap_err();
519
520                assert!(
521                    matches!(err, AutoRefreshError::Expired),
522                    "expected Expired, got: {err:?}"
523                );
524            }
525        }
526
527        mod with_refresh_token {
528            use super::*;
529
530            #[tokio::test]
531            async fn refreshes_and_returns_new_token() {
532                let mut mocks = MockSet::new();
533                mocks.mock(|when, then| {
534                    when.post().path("/oauth/token");
535                    then.json(refresh_response_json("refreshed-token"));
536                });
537                let server = start_server(mocks).await;
538                let dir = tempfile::tempdir().unwrap();
539                let strategy =
540                    auto_refresh_with_token(&dir, &server, make_token("old-token", 0, true));
541
542                let token = strategy.get_token().await.unwrap();
543
544                assert_eq!(
545                    token.as_str(),
546                    "refreshed-token",
547                    "should return the refreshed token"
548                );
549            }
550
551            #[tokio::test]
552            async fn persists_refreshed_token_to_disk() {
553                let mut mocks = MockSet::new();
554                mocks.mock(|when, then| {
555                    when.post().path("/oauth/token");
556                    then.json(refresh_response_json("refreshed-token"));
557                });
558                let server = start_server(mocks).await;
559                let dir = tempfile::tempdir().unwrap();
560                let strategy =
561                    auto_refresh_with_token(&dir, &server, make_token("old-token", 0, true));
562
563                let _ = strategy.get_token().await.unwrap();
564
565                // Verify the refreshed token was saved to the workspace directory.
566                let store = ProfileStore::new(dir.path());
567                let ws_store = store.current_workspace_store().unwrap();
568                let on_disk: Token = ws_store.load_profile().unwrap();
569                assert_eq!(
570                    on_disk.access_token().as_str(),
571                    "refreshed-token",
572                    "refreshed token should be persisted to disk"
573                );
574            }
575
576            #[tokio::test]
577            async fn returns_expired_on_refresh_failure() {
578                let mut mocks = MockSet::new();
579                mocks.mock(|when, then| {
580                    when.post().path("/oauth/token");
581                    then.bad_request().json(error_json("invalid_grant"));
582                });
583                let server = start_server(mocks).await;
584                let dir = tempfile::tempdir().unwrap();
585                let strategy =
586                    auto_refresh_with_token(&dir, &server, make_token("old-token", 0, true));
587
588                let err = strategy.get_token().await.unwrap_err();
589
590                assert!(
591                    matches!(err, AutoRefreshError::Expired),
592                    "expected Expired after failed refresh, got: {err:?}"
593                );
594            }
595
596            #[tokio::test]
597            async fn restores_refresh_token_after_failure() {
598                let mut mocks = MockSet::new();
599                mocks.mock(|when, then| {
600                    when.post().path("/oauth/token");
601                    then.bad_request().json(error_json("invalid_grant"));
602                });
603                let server = start_server(mocks).await;
604                let dir = tempfile::tempdir().unwrap();
605                let strategy =
606                    auto_refresh_with_token(&dir, &server, make_token("old-token", 0, true));
607
608                // First call: refresh fails, returns Expired.
609                let err = strategy.get_token().await.unwrap_err();
610                assert!(
611                    matches!(err, AutoRefreshError::Expired),
612                    "expected Expired on first attempt, got: {err:?}"
613                );
614
615                // Verify the refresh token was restored so a retry is possible.
616                let state = strategy.state.lock().await;
617                assert!(
618                    state.token.is_some(),
619                    "token should still be cached after failed refresh"
620                );
621                assert!(
622                    state.token.as_ref().unwrap().refresh_token().is_some(),
623                    "refresh token should be restored for retry"
624                );
625                drop(state);
626
627                // Replace mock with a success response.
628                server.mocks().clear();
629                server.mocks().mock(|when, then| {
630                    when.post().path("/oauth/token");
631                    then.json(refresh_response_json("refreshed-token"));
632                });
633
634                // Second call: refresh token is available → retry succeeds.
635                let token = strategy.get_token().await.unwrap();
636                assert_eq!(
637                    token.as_str(),
638                    "refreshed-token",
639                    "retry should succeed with restored refresh token"
640                );
641            }
642
643            #[tokio::test]
644            async fn sequential_calls_only_refresh_once() {
645                let mut mocks = MockSet::new();
646                mocks.mock(|when, then| {
647                    when.post().path("/oauth/token");
648                    then.json(refresh_response_json("refreshed-once"));
649                });
650                let server = start_server(mocks).await;
651                let dir = tempfile::tempdir().unwrap();
652                let strategy =
653                    auto_refresh_with_token(&dir, &server, make_token("old-token", 0, true));
654
655                // First call triggers refresh.
656                let token = strategy.get_token().await.unwrap();
657                assert_eq!(
658                    token.as_str(),
659                    "refreshed-once",
660                    "first call should trigger refresh"
661                );
662
663                // Swap mock to track if another refresh is attempted.
664                server.mocks().clear();
665                server.mocks().mock(|when, then| {
666                    when.post().path("/oauth/token");
667                    then.json(refresh_response_json("refreshed-twice"));
668                });
669
670                // Calls 2-5: the refreshed token is fresh, so no further refresh.
671                for _ in 0..4 {
672                    let token = strategy.get_token().await.unwrap();
673                    assert_eq!(
674                        token.as_str(),
675                        "refreshed-once",
676                        "should return cached refreshed token, not trigger another refresh"
677                    );
678                }
679            }
680
681            #[tokio::test]
682            async fn prevents_second_refresh_after_success() {
683                let mut mocks = MockSet::new();
684                mocks.mock(|when, then| {
685                    when.post().path("/oauth/token");
686                    then.json(refresh_response_json("refreshed-token"));
687                });
688                let server = start_server(mocks).await;
689                let dir = tempfile::tempdir().unwrap();
690                let strategy =
691                    auto_refresh_with_token(&dir, &server, make_token("old-token", 0, true));
692
693                // First call refreshes successfully.
694                let token = strategy.get_token().await.unwrap();
695                assert_eq!(
696                    token.as_str(),
697                    "refreshed-token",
698                    "first call should refresh the token"
699                );
700
701                // Replace the mock with one that errors.
702                server.mocks().clear();
703                server.mocks().mock(|when, then| {
704                    when.post().path("/oauth/token");
705                    then.bad_request().json(error_json("should_not_be_called"));
706                });
707
708                // Second call should return the refreshed token without hitting
709                // the server again (the new token has a fresh expiry).
710                let token = strategy.get_token().await.unwrap();
711                assert_eq!(
712                    token.as_str(),
713                    "refreshed-token",
714                    "second call should return cached refreshed token"
715                );
716            }
717        }
718    }
719
720    mod given_expiring_but_usable_token {
721        use super::*;
722
723        mod when_refresh_fails {
724            use super::*;
725
726            #[tokio::test]
727            async fn returns_current_token() {
728                let mut mocks = MockSet::new();
729                mocks.mock(|when, then| {
730                    when.post().path("/oauth/token");
731                    then.bad_request().json(error_json("server_error"));
732                });
733                let server = start_server(mocks).await;
734                let dir = tempfile::tempdir().unwrap();
735                // Token expires in 30s (within the 90s leeway so is_expired() = true),
736                // but the access token is still technically usable.
737                let strategy =
738                    auto_refresh_with_token(&dir, &server, make_token("still-usable", 30, true));
739
740                // The refresh fails, but the access token should still be returned
741                // because it's still usable (30s remaining > 0).
742                let token = strategy.get_token().await.unwrap();
743                assert_eq!(
744                    token.as_str(),
745                    "still-usable",
746                    "should return still-usable token despite failed refresh"
747                );
748
749                // Verify the access token and refresh token are still present.
750                let state = strategy.state.lock().await;
751                assert!(state.token.is_some(), "token should still be cached");
752                assert_eq!(
753                    state.token.as_ref().unwrap().access_token().as_str(),
754                    "still-usable",
755                    "access token should be unchanged after failed refresh"
756                );
757                assert!(
758                    state.token.as_ref().unwrap().refresh_token().is_some(),
759                    "refresh token should be restored after failed refresh"
760                );
761            }
762
763            #[tokio::test]
764            async fn restores_refresh_token_for_retry() {
765                let mut mocks = MockSet::new();
766                mocks.mock(|when, then| {
767                    when.post().path("/oauth/token");
768                    then.bad_request().json(error_json("server_error"));
769                });
770                let server = start_server(mocks).await;
771                let dir = tempfile::tempdir().unwrap();
772                // Token expires in 30s — is_expired() = true, is_usable() = true.
773                let strategy =
774                    auto_refresh_with_token(&dir, &server, make_token("still-usable", 30, true));
775
776                // First call: refresh fails, but the still-usable token is returned.
777                let token = strategy.get_token().await.unwrap();
778                assert_eq!(
779                    token.as_str(),
780                    "still-usable",
781                    "first call should return still-usable token"
782                );
783
784                // Replace mock with a success response.
785                server.mocks().clear();
786                server.mocks().mock(|when, then| {
787                    when.post().path("/oauth/token");
788                    then.json(refresh_response_json("refreshed-token"));
789                });
790
791                // Second call: refresh token was restored, so the retry succeeds.
792                let token = strategy.get_token().await.unwrap();
793                assert!(
794                    token.as_str() == "still-usable" || token.as_str() == "refreshed-token",
795                    "expected old or refreshed token, got: {}",
796                    token.as_str()
797                );
798
799                // Verify the cache now holds the refreshed token.
800                let state = strategy.state.lock().await;
801                assert_eq!(
802                    state.token.as_ref().unwrap().access_token().as_str(),
803                    "refreshed-token",
804                    "cache should hold the refreshed token after retry"
805                );
806            }
807        }
808    }
809
810    mod given_concurrent_callers {
811        use super::*;
812
813        #[tokio::test]
814        async fn returns_usable_token_while_refreshing() {
815            let mut mocks = MockSet::new();
816            mocks.mock(|when, then| {
817                when.post().path("/oauth/token");
818                then.json(refresh_response_json("refreshed-token"));
819            });
820            let server = start_server(mocks).await;
821            let dir = tempfile::tempdir().unwrap();
822            let strategy = Arc::new(auto_refresh_with_token(
823                &dir,
824                &server,
825                make_token("still-usable", 30, true),
826            ));
827
828            let s1 = Arc::clone(&strategy);
829            let handle_a = tokio::spawn(async move { s1.get_token().await.unwrap() });
830
831            let s2 = Arc::clone(&strategy);
832            let handle_b = tokio::spawn(async move { s2.get_token().await.unwrap() });
833
834            let (result_a, result_b) = tokio::join!(handle_a, handle_b);
835            let token_a = result_a.unwrap();
836            let token_b = result_b.unwrap();
837
838            assert!(
839                token_a.as_str() == "still-usable" || token_a.as_str() == "refreshed-token",
840                "unexpected token_a: {}",
841                token_a.as_str()
842            );
843            assert!(
844                token_b.as_str() == "still-usable" || token_b.as_str() == "refreshed-token",
845                "unexpected token_b: {}",
846                token_b.as_str()
847            );
848        }
849
850        #[tokio::test]
851        async fn blocks_until_refresh_completes() {
852            let mut mocks = MockSet::new();
853            mocks.mock(|when, then| {
854                when.post().path("/oauth/token");
855                then.json(refresh_response_json("refreshed-token"));
856            });
857            let server = start_server(mocks).await;
858            let dir = tempfile::tempdir().unwrap();
859            let strategy = Arc::new(auto_refresh_with_token(
860                &dir,
861                &server,
862                make_token("expired-token", 0, true),
863            ));
864
865            let s1 = Arc::clone(&strategy);
866            let handle_a = tokio::spawn(async move { s1.get_token().await.unwrap() });
867
868            let s2 = Arc::clone(&strategy);
869            let handle_b = tokio::spawn(async move { s2.get_token().await.unwrap() });
870
871            let (result_a, result_b) = tokio::join!(handle_a, handle_b);
872            let token_a = result_a.unwrap();
873            let token_b = result_b.unwrap();
874
875            assert_eq!(
876                token_a.as_str(),
877                "refreshed-token",
878                "caller a should receive refreshed token"
879            );
880            assert_eq!(
881                token_b.as_str(),
882                "refreshed-token",
883                "caller b should receive refreshed token"
884            );
885        }
886    }
887}
888
889#[cfg(test)]
890#[allow(clippy::unwrap_used)]
891mod stress_tests {
892    use super::*;
893    use crate::oauth_refresher::OAuthRefresher;
894    use crate::SecretToken;
895    use stack_profile::ProfileStore;
896    use std::sync::atomic::{AtomicUsize, Ordering};
897    use std::sync::Arc;
898    use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
899
900    /// Tracks in-flight and peak concurrency for test assertions.
901    #[derive(Clone)]
902    struct CountingState {
903        total: Arc<AtomicUsize>,
904        current: Arc<AtomicUsize>,
905        peak: Arc<AtomicUsize>,
906    }
907
908    impl CountingState {
909        fn new() -> Self {
910            Self {
911                total: Arc::new(AtomicUsize::new(0)),
912                current: Arc::new(AtomicUsize::new(0)),
913                peak: Arc::new(AtomicUsize::new(0)),
914            }
915        }
916
917        fn enter(&self) {
918            self.total.fetch_add(1, Ordering::SeqCst);
919            let prev = self.current.fetch_add(1, Ordering::SeqCst);
920            self.peak.fetch_max(prev + 1, Ordering::SeqCst);
921        }
922
923        fn exit(&self) {
924            self.current.fetch_sub(1, Ordering::SeqCst);
925        }
926
927        fn peak(&self) -> usize {
928            self.peak.load(Ordering::SeqCst)
929        }
930
931        fn total(&self) -> usize {
932            self.total.load(Ordering::SeqCst)
933        }
934    }
935
936    #[derive(Clone)]
937    struct DelayedRefreshState {
938        counting: CountingState,
939        delay: Duration,
940    }
941
942    async fn delayed_refresh_handler(
943        axum::extract::State(state): axum::extract::State<DelayedRefreshState>,
944    ) -> axum::Json<serde_json::Value> {
945        state.counting.enter();
946        tokio::time::sleep(state.delay).await;
947        state.counting.exit();
948        axum::Json(serde_json::json!({
949            "access_token": "refreshed-token",
950            "token_type": "Bearer",
951            "expires_in": 3600,
952            "refresh_token": "new-refresh-token"
953        }))
954    }
955
956    async fn delayed_error_handler(
957        axum::extract::State(state): axum::extract::State<DelayedRefreshState>,
958    ) -> (axum::http::StatusCode, axum::Json<serde_json::Value>) {
959        state.counting.enter();
960        tokio::time::sleep(state.delay).await;
961        state.counting.exit();
962        (
963            axum::http::StatusCode::BAD_REQUEST,
964            axum::Json(serde_json::json!({
965                "error": "invalid_grant",
966                "error_description": "invalid_grant occurred"
967            })),
968        )
969    }
970
971    async fn start_axum_server<H, T>(
972        handler: H,
973        state: DelayedRefreshState,
974    ) -> (url::Url, CountingState)
975    where
976        H: axum::handler::Handler<T, DelayedRefreshState> + Clone + Send + 'static,
977        T: 'static,
978    {
979        let counting = state.counting.clone();
980        let app = axum::Router::new()
981            .route("/oauth/token", axum::routing::post(handler))
982            .with_state(state);
983        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
984        let addr = listener.local_addr().unwrap();
985        tokio::spawn(async move {
986            axum::serve(listener, app).await.unwrap();
987        });
988        let base_url = url::Url::parse(&format!("http://{addr}")).unwrap();
989        (base_url, counting)
990    }
991
992    fn make_token(access: &str, expires_in: u64, refresh: bool) -> Token {
993        let now = SystemTime::now()
994            .duration_since(UNIX_EPOCH)
995            .unwrap()
996            .as_secs();
997
998        Token {
999            access_token: SecretToken::new(access),
1000            token_type: "Bearer".to_string(),
1001            expires_at: now + expires_in,
1002            refresh_token: if refresh {
1003                Some(SecretToken::new("test-refresh-token"))
1004            } else {
1005                None
1006            },
1007            region: None,
1008            client_id: None,
1009            device_instance_id: None,
1010        }
1011    }
1012
1013    fn auto_refresh_with_token(
1014        dir: &tempfile::TempDir,
1015        base_url: &url::Url,
1016        token: Token,
1017    ) -> AutoRefresh<OAuthRefresher> {
1018        let store = ProfileStore::new(dir.path());
1019        store.init_workspace("ZVATKW3VHMFG27DY").unwrap();
1020        let ws_store = store.current_workspace_store().unwrap();
1021        ws_store.save_profile(&token).unwrap();
1022        let refresher = OAuthRefresher::new(
1023            Some(ws_store),
1024            base_url.clone(),
1025            "cli",
1026            "ap-southeast-2.aws",
1027            None,
1028        );
1029        AutoRefresh::with_token(refresher, token)
1030    }
1031
1032    const CONCURRENCY: usize = 50;
1033
1034    mod given_fresh_token {
1035        use super::*;
1036
1037        #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1038        async fn all_callers_return_immediately() {
1039            let counting = CountingState::new();
1040            let state = DelayedRefreshState {
1041                counting: counting.clone(),
1042                delay: Duration::from_millis(500),
1043            };
1044            let (base_url, stats) = start_axum_server(delayed_refresh_handler, state).await;
1045            let dir = tempfile::tempdir().unwrap();
1046            let strategy = Arc::new(auto_refresh_with_token(
1047                &dir,
1048                &base_url,
1049                make_token("fresh-token", 3600, true),
1050            ));
1051
1052            let start = Instant::now();
1053            let mut handles = Vec::with_capacity(CONCURRENCY);
1054            for _ in 0..CONCURRENCY {
1055                let s = Arc::clone(&strategy);
1056                handles.push(tokio::spawn(async move { s.get_token().await.unwrap() }));
1057            }
1058
1059            let results: Vec<_> = {
1060                let mut results = Vec::with_capacity(handles.len());
1061                for handle in handles {
1062                    results.push(handle.await.unwrap());
1063                }
1064                results
1065            };
1066            let elapsed = start.elapsed();
1067
1068            for token in &results {
1069                assert_eq!(
1070                    token.as_str(),
1071                    "fresh-token",
1072                    "all callers should receive the fresh token"
1073                );
1074            }
1075
1076            assert!(
1077                elapsed < Duration::from_millis(200),
1078                "expected < 200ms for fresh tokens, got {:?}",
1079                elapsed
1080            );
1081            assert_eq!(stats.total(), 0, "no refresh requests should be made");
1082        }
1083    }
1084
1085    mod given_expiring_but_usable_token {
1086        use super::*;
1087
1088        #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1089        async fn non_blocking_reads_during_refresh() {
1090            let counting = CountingState::new();
1091            let state = DelayedRefreshState {
1092                counting: counting.clone(),
1093                delay: Duration::from_millis(500),
1094            };
1095            let (base_url, stats) = start_axum_server(delayed_refresh_handler, state).await;
1096            let dir = tempfile::tempdir().unwrap();
1097            let strategy = Arc::new(auto_refresh_with_token(
1098                &dir,
1099                &base_url,
1100                make_token("still-usable", 30, true),
1101            ));
1102
1103            let start = Instant::now();
1104            let mut handles = Vec::with_capacity(CONCURRENCY);
1105            for _ in 0..CONCURRENCY {
1106                let s = Arc::clone(&strategy);
1107                handles.push(tokio::spawn(async move {
1108                    let call_start = Instant::now();
1109                    let token = s.get_token().await.unwrap();
1110                    (token, call_start.elapsed())
1111                }));
1112            }
1113
1114            let results: Vec<_> = {
1115                let mut results = Vec::with_capacity(handles.len());
1116                for handle in handles {
1117                    results.push(handle.await.unwrap());
1118                }
1119                results
1120            };
1121            let elapsed = start.elapsed();
1122
1123            for (token, _) in &results {
1124                assert!(
1125                    token.as_str() == "still-usable" || token.as_str() == "refreshed-token",
1126                    "unexpected token: {}",
1127                    token.as_str()
1128                );
1129            }
1130
1131            let fast_callers = results
1132                .iter()
1133                .filter(|(_, dur)| *dur < Duration::from_millis(100))
1134                .count();
1135            assert!(
1136                fast_callers >= CONCURRENCY - 1,
1137                "expected at least {} fast callers, got {} (total elapsed: {:?})",
1138                CONCURRENCY - 1,
1139                fast_callers,
1140                elapsed
1141            );
1142
1143            assert_eq!(stats.peak(), 1, "peak concurrency to refresh endpoint");
1144            assert_eq!(stats.total(), 1, "total refresh requests");
1145        }
1146
1147        /// Reproduces the race condition where a token crosses real expiry during
1148        /// an in-flight non-blocking refresh. Before the fix, late-arriving callers
1149        /// would see `refresh_in_progress = true` + `!is_usable()` and return
1150        /// `Err(Expired)` instead of waiting for the refresh to complete.
1151        #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1152        async fn waiters_receive_token_when_expiry_crosses() {
1153            // Token with 1s until real expiry (minimum granularity since
1154            // expires_at is in seconds). is_expired() = true (within 90s leeway),
1155            // is_usable() = true (1s remaining). Refresh takes 1.5s so the token
1156            // crosses real expiry mid-refresh.
1157            let refresh_delay = Duration::from_millis(1500);
1158            let counting = CountingState::new();
1159            let state = DelayedRefreshState {
1160                counting: counting.clone(),
1161                delay: refresh_delay,
1162            };
1163            let (base_url, stats) = start_axum_server(delayed_refresh_handler, state).await;
1164            let dir = tempfile::tempdir().unwrap();
1165            let strategy = Arc::new(auto_refresh_with_token(
1166                &dir,
1167                &base_url,
1168                make_token("expiring-soon", 1, true),
1169            ));
1170
1171            // First caller triggers the non-blocking refresh and gets the old token.
1172            let first = strategy.get_token().await.unwrap();
1173            assert_eq!(
1174                first.as_str(),
1175                "expiring-soon",
1176                "first caller should receive the expiring token"
1177            );
1178
1179            // Wait for the token to cross real expiry (but refresh is still in-flight).
1180            tokio::time::sleep(Duration::from_millis(1100)).await;
1181
1182            // Launch 50 concurrent callers. Without the fix, these would all get
1183            // Err(Expired) because refresh_in_progress = true and !is_usable().
1184            let mut handles = Vec::with_capacity(CONCURRENCY);
1185            for _ in 0..CONCURRENCY {
1186                let s = Arc::clone(&strategy);
1187                handles.push(tokio::spawn(async move { s.get_token().await }));
1188            }
1189
1190            let results: Vec<_> = {
1191                let mut results = Vec::with_capacity(handles.len());
1192                for handle in handles {
1193                    results.push(handle.await.unwrap());
1194                }
1195                results
1196            };
1197
1198            // All callers must succeed — none should get Expired.
1199            for (i, result) in results.iter().enumerate() {
1200                assert!(
1201                    result.is_ok(),
1202                    "caller {i} got Err({:?}), expected Ok",
1203                    result.as_ref().unwrap_err()
1204                );
1205                assert_eq!(
1206                    result.as_ref().unwrap().as_str(),
1207                    "refreshed-token",
1208                    "caller {i} should receive the refreshed token"
1209                );
1210            }
1211
1212            assert_eq!(stats.total(), 1, "only one refresh request should be made");
1213        }
1214    }
1215
1216    mod given_fully_expired_token {
1217        use super::*;
1218
1219        #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1220        async fn all_callers_block_until_refresh() {
1221            let refresh_delay = Duration::from_millis(200);
1222            let counting = CountingState::new();
1223            let state = DelayedRefreshState {
1224                counting: counting.clone(),
1225                delay: refresh_delay,
1226            };
1227            let (base_url, stats) = start_axum_server(delayed_refresh_handler, state).await;
1228            let dir = tempfile::tempdir().unwrap();
1229            let strategy = Arc::new(auto_refresh_with_token(
1230                &dir,
1231                &base_url,
1232                make_token("expired-token", 0, true),
1233            ));
1234
1235            let start = Instant::now();
1236            let mut handles = Vec::with_capacity(CONCURRENCY);
1237            for _ in 0..CONCURRENCY {
1238                let s = Arc::clone(&strategy);
1239                handles.push(tokio::spawn(async move { s.get_token().await.unwrap() }));
1240            }
1241
1242            let results: Vec<_> = {
1243                let mut results = Vec::with_capacity(handles.len());
1244                for handle in handles {
1245                    results.push(handle.await.unwrap());
1246                }
1247                results
1248            };
1249            let elapsed = start.elapsed();
1250
1251            for token in &results {
1252                assert_eq!(
1253                    token.as_str(),
1254                    "refreshed-token",
1255                    "all callers should receive refreshed token"
1256                );
1257            }
1258
1259            assert!(
1260                elapsed < refresh_delay + Duration::from_millis(200),
1261                "expected < {:?} for blocked callers, got {:?}",
1262                refresh_delay + Duration::from_millis(200),
1263                elapsed
1264            );
1265
1266            assert_eq!(stats.peak(), 1, "peak concurrency to refresh endpoint");
1267            assert_eq!(stats.total(), 1, "total refresh requests");
1268        }
1269
1270        #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1271        async fn all_callers_receive_expired_on_failure() {
1272            let counting = CountingState::new();
1273            let state = DelayedRefreshState {
1274                counting: counting.clone(),
1275                delay: Duration::from_millis(10),
1276            };
1277            let (base_url, stats) = start_axum_server(delayed_error_handler, state).await;
1278            let dir = tempfile::tempdir().unwrap();
1279            let strategy = Arc::new(auto_refresh_with_token(
1280                &dir,
1281                &base_url,
1282                make_token("expired-token", 0, true),
1283            ));
1284
1285            let mut handles = Vec::with_capacity(CONCURRENCY);
1286            for _ in 0..CONCURRENCY {
1287                let s = Arc::clone(&strategy);
1288                handles.push(tokio::spawn(async move { s.get_token().await }));
1289            }
1290
1291            let results: Vec<_> = {
1292                let mut results = Vec::with_capacity(handles.len());
1293                for handle in handles {
1294                    results.push(handle.await.unwrap());
1295                }
1296                results
1297            };
1298
1299            for result in &results {
1300                assert!(result.is_err(), "expected Expired error, got Ok");
1301                let err = result.as_ref().unwrap_err();
1302                assert!(
1303                    matches!(err, AutoRefreshError::Expired),
1304                    "expected Expired, got: {err:?}"
1305                );
1306            }
1307
1308            let state = strategy.state.lock().await;
1309            assert!(
1310                state.token.as_ref().unwrap().refresh_token().is_some(),
1311                "refresh token should be restored after failed refresh"
1312            );
1313            drop(state);
1314
1315            assert_eq!(stats.peak(), 1, "peak concurrency to refresh endpoint");
1316            assert!(
1317                stats.total() >= 1,
1318                "at least one refresh attempt should be made"
1319            );
1320        }
1321
1322        #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1323        async fn retry_succeeds_after_failure() {
1324            // Phase 1: Server returns errors.
1325            let counting1 = CountingState::new();
1326            let state1 = DelayedRefreshState {
1327                counting: counting1.clone(),
1328                delay: Duration::from_millis(50),
1329            };
1330            let (base_url, _) = start_axum_server(delayed_error_handler, state1).await;
1331            let dir = tempfile::tempdir().unwrap();
1332            let strategy = Arc::new(auto_refresh_with_token(
1333                &dir,
1334                &base_url,
1335                make_token("expired-token", 0, true),
1336            ));
1337
1338            let mut handles = Vec::with_capacity(CONCURRENCY);
1339            for _ in 0..CONCURRENCY {
1340                let s = Arc::clone(&strategy);
1341                handles.push(tokio::spawn(async move { s.get_token().await }));
1342            }
1343
1344            let results: Vec<_> = {
1345                let mut results = Vec::with_capacity(handles.len());
1346                for handle in handles {
1347                    results.push(handle.await.unwrap());
1348                }
1349                results
1350            };
1351
1352            for result in &results {
1353                assert!(
1354                    result.is_err(),
1355                    "first wave: expected Expired, got Ok({})",
1356                    result.as_ref().unwrap().as_str()
1357                );
1358            }
1359
1360            // Phase 2: New server that returns success.
1361            let counting2 = CountingState::new();
1362            let state2 = DelayedRefreshState {
1363                counting: counting2.clone(),
1364                delay: Duration::from_millis(50),
1365            };
1366            let (base_url2, stats2) = start_axum_server(delayed_refresh_handler, state2).await;
1367
1368            let strategy2 = Arc::new(auto_refresh_with_token(
1369                &dir,
1370                &base_url2,
1371                make_token("expired-token", 0, true),
1372            ));
1373
1374            let mut handles = Vec::with_capacity(CONCURRENCY);
1375            for _ in 0..CONCURRENCY {
1376                let s = Arc::clone(&strategy2);
1377                handles.push(tokio::spawn(async move { s.get_token().await.unwrap() }));
1378            }
1379
1380            let results: Vec<_> = {
1381                let mut results = Vec::with_capacity(handles.len());
1382                for handle in handles {
1383                    results.push(handle.await.unwrap());
1384                }
1385                results
1386            };
1387
1388            for token in &results {
1389                assert_eq!(
1390                    token.as_str(),
1391                    "refreshed-token",
1392                    "retry callers should receive refreshed token"
1393                );
1394            }
1395
1396            assert_eq!(stats2.total(), 1, "only one retry refresh should be made");
1397        }
1398    }
1399
1400    mod given_cancelled_refresh {
1401        use super::*;
1402
1403        /// If a blocking refresh (fully expired token) is cancelled mid-flight,
1404        /// the `CancelGuard` must reset `refresh_in_progress` and notify waiters
1405        /// so the next caller doesn't hang in `wait_for_in_flight_refresh`.
1406        #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1407        async fn blocked_callers_recover_after_cancellation() {
1408            let counting = CountingState::new();
1409            let state = DelayedRefreshState {
1410                counting: counting.clone(),
1411                delay: Duration::from_secs(10), // Very slow — will be cancelled
1412            };
1413            let (base_url, _) = start_axum_server(delayed_refresh_handler, state).await;
1414            let dir = tempfile::tempdir().unwrap();
1415            let strategy = Arc::new(auto_refresh_with_token(
1416                &dir,
1417                &base_url,
1418                make_token("expired-token", 0, true),
1419            ));
1420
1421            // Spawn get_token and let the blocking refresh start.
1422            let s = Arc::clone(&strategy);
1423            let handle = tokio::spawn(async move { s.get_token().await });
1424            tokio::time::sleep(Duration::from_millis(100)).await;
1425
1426            // Cancel the refresh mid-flight.
1427            handle.abort();
1428            let _ = handle.await;
1429
1430            // The next caller must not hang. The credential is lost (refresh
1431            // token was taken before the HTTP call), so the result is Expired,
1432            // but the important thing is that it completes promptly.
1433            let s = Arc::clone(&strategy);
1434            let result = tokio::time::timeout(Duration::from_secs(2), s.get_token()).await;
1435
1436            assert!(
1437                result.is_ok(),
1438                "get_token() should not hang after cancelled blocking refresh"
1439            );
1440        }
1441
1442        /// If a non-blocking refresh (expiring-but-usable token) is cancelled
1443        /// mid-flight, the `CancelGuard` must reset `refresh_in_progress` and
1444        /// notify waiters so they don't hang once the token crosses real expiry.
1445        #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1446        async fn non_blocking_callers_recover_after_cancellation() {
1447            let counting = CountingState::new();
1448            let state = DelayedRefreshState {
1449                counting: counting.clone(),
1450                delay: Duration::from_secs(10), // Very slow — will be cancelled
1451            };
1452            let (base_url, _) = start_axum_server(delayed_refresh_handler, state).await;
1453            let dir = tempfile::tempdir().unwrap();
1454            // Token expires in 30s — is_expired() = true, is_usable() = true.
1455            let strategy = Arc::new(auto_refresh_with_token(
1456                &dir,
1457                &base_url,
1458                make_token("still-usable", 30, true),
1459            ));
1460
1461            // Spawn get_token — triggers non-blocking refresh, drops lock, then
1462            // blocks on the slow HTTP call.
1463            let s = Arc::clone(&strategy);
1464            let handle = tokio::spawn(async move { s.get_token().await });
1465            tokio::time::sleep(Duration::from_millis(100)).await;
1466
1467            // Cancel the refresh mid-flight.
1468            handle.abort();
1469            let _ = handle.await;
1470
1471            // The next caller must not hang. The token is still usable so it
1472            // should be returned even though the refresh was cancelled.
1473            let s = Arc::clone(&strategy);
1474            let result = tokio::time::timeout(Duration::from_secs(2), s.get_token()).await;
1475
1476            assert!(
1477                result.is_ok(),
1478                "get_token() should not hang after cancelled non-blocking refresh"
1479            );
1480            let result = result.unwrap();
1481            assert!(
1482                result.is_ok(),
1483                "expected Ok with still-usable token, got: {:?}",
1484                result.unwrap_err()
1485            );
1486        }
1487    }
1488}