// stack_auth/oidc_refresher.rs

1use std::future::Future;
2use std::sync::Arc;
3
4use cts_common::WorkspaceId;
5use url::Url;
6use web_time::{SystemTime, UNIX_EPOCH};
7
8use crate::authorize_dto::AuthoriseResponse;
9use crate::refresher::Refresher;
10use crate::{http_client, AuthError, SecretToken, Token};
11
12/// Asynchronously supplies the *current* third-party OIDC JWT to federate.
13///
14/// [`OidcFederationStrategy`](crate::OidcFederationStrategy) re-invokes this on every refresh:
15/// `/api/authorise` issues no CTS refresh token, so renewing an expired CTS
16/// token means re-federating with a fresh provider JWT. Implementations
17/// typically wrap a provider SDK call (`clerk.session.getToken()`,
18/// `supabase.auth.getSession()`), an FFI callback, or a test double.
19///
20/// On native targets the trait carries `Send + Sync` bounds so the provider
21/// can be driven from `tokio::spawn` background work. On wasm32 the bounds
22/// are dropped — reqwest's fetch-backed futures are not `Send` and edge
23/// runtimes are single-threaded anyway.
24#[cfg(not(target_arch = "wasm32"))]
25pub trait OidcProvider: Send + Sync {
26    /// Fetch the current third-party OIDC JWT to federate.
27    fn fetch(&self) -> impl Future<Output = Result<SecretToken, AuthError>> + Send;
28}
29
30/// Wasm32 variant of [`OidcProvider`] — drops the `Send + Sync` bounds.
31#[cfg(target_arch = "wasm32")]
32pub trait OidcProvider {
33    /// Fetch the current third-party OIDC JWT to federate.
34    fn fetch(&self) -> impl Future<Output = Result<SecretToken, AuthError>>;
35}
36
/// [`OidcProvider`] backed by a user-supplied async closure.
///
/// The closure fires on every federation: the initial auth and every
/// re-federation after expiry. It must therefore return the *current* JWT
/// each time, not a value captured once.
///
/// The point of the closure is to defer to the provider's own session
/// machinery on every call. A provider SDK (`clerk.session.getToken()`,
/// `supabase.auth.getSession()`) hands back a freshly minted short-lived JWT
/// and refreshes its own session as needed. Capturing a single token up front
/// would instead pin a JWT that expires and can never be renewed.
///
/// # Example
///
/// ```no_run
/// use stack_auth::{AuthError, OidcProviderFn, SecretToken};
///
/// # async fn clerk_session_get_token() -> Result<String, AuthError> { Ok(String::new()) }
/// // Each call asks the provider SDK for the *current* session token, so an
/// // expired JWT is refreshed upstream rather than reused.
/// let provider = OidcProviderFn::new(|| async {
///     let jwt = clerk_session_get_token().await?;
///     Ok::<_, AuthError>(SecretToken::new(jwt))
/// });
/// ```
pub struct OidcProviderFn<F> {
    fetch: F,
}

/// Opaque `Debug` output (`OidcProviderFn { .. }`).
///
/// Closures never implement `Debug`, so a derived impl (which would require
/// `F: Debug`) would be unusable in practice. This hand-written impl places
/// no bound on `F`, so any type containing an `OidcProviderFn` can still
/// derive `Debug`.
impl<F> std::fmt::Debug for OidcProviderFn<F> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OidcProviderFn").finish_non_exhaustive()
    }
}
64
65impl<F> OidcProviderFn<F> {
66    /// Build an `OidcProviderFn` from an async closure returning the current JWT.
67    pub fn new(fetch: F) -> Self {
68        Self { fetch }
69    }
70}
71
72#[cfg(not(target_arch = "wasm32"))]
73impl<F, Fut> OidcProvider for OidcProviderFn<F>
74where
75    F: Fn() -> Fut + Send + Sync,
76    Fut: Future<Output = Result<SecretToken, AuthError>> + Send,
77{
78    fn fetch(&self) -> impl Future<Output = Result<SecretToken, AuthError>> + Send {
79        (self.fetch)()
80    }
81}
82
#[cfg(target_arch = "wasm32")]
impl<F, Fut> OidcProvider for OidcProviderFn<F>
where
    F: Fn() -> Fut,
    Fut: Future<Output = Result<SecretToken, AuthError>>,
{
    /// Invoke the wrapped closure once and return the future it produced.
    /// No `Send` requirement on wasm32.
    fn fetch(&self) -> impl Future<Output = Result<SecretToken, AuthError>> {
        let make_future = &self.fetch;
        make_future()
    }
}
93
94/// A [`Refresher`] that federates a third-party OIDC JWT into a CTS service
95/// token via `POST /api/authorise`.
96///
97/// *Our* federation step is stateless: the credential is always available (the
98/// [`OidcProvider`] is re-callable), so `try_credential` returns `Some(())` and
99/// `restore` is a no-op — exactly like
100/// [`AccessKeyRefresher`](crate::access_key_refresher). The *upstream* OIDC
101/// provider that issues the JWT is typically not stateless — it usually relies
102/// on its own session machinery (cookies, a session store, a refresh token)
103/// to mint the short-lived JWT that [`OidcProvider::fetch`] returns.
104///
105/// `/api/authorise` issues no CTS refresh token. Keeping the federated CTS
106/// token fresh is therefore the upstream caller's responsibility, via whatever
107/// mechanism the provider requires: when the CTS token expires, `AutoRefresh`
108/// renews it by calling `refresh` again, which re-invokes the `OidcProvider`
109/// for a current JWT — so a provider that hands back an expired or stale JWT
110/// will produce an expired or stale CTS token in turn.
111pub(crate) struct OidcRefresher<P> {
112    oidc_provider: P,
113    workspace_id: WorkspaceId,
114    base_url: Url,
115    http_client: Arc<reqwest::Client>,
116}
117
118impl<P> OidcRefresher<P> {
119    pub(crate) fn new(oidc_provider: P, workspace_id: WorkspaceId, base_url: Url) -> Self {
120        Self {
121            oidc_provider,
122            workspace_id,
123            base_url,
124            http_client: Arc::new(http_client()),
125        }
126    }
127}
128
129impl<P: OidcProvider> Refresher for OidcRefresher<P> {
130    type Credential = ();
131
132    fn save(&self, _token: &Token) {
133        // Federated tokens are ephemeral — no per-refresher persistence.
134    }
135
136    fn try_credential(&self, _token: Option<&mut Token>) -> Option<Self::Credential> {
137        // The OIDC provider is always re-callable, so federation can always be
138        // attempted — including on cold start (initial auth).
139        Some(())
140    }
141
142    fn restore(&self, _token: &mut Token, _credential: Self::Credential) {
143        // Nothing to restore — the OIDC provider is re-callable.
144    }
145
146    async fn refresh(&self, _credential: &Self::Credential) -> Result<Token, AuthError> {
147        let oidc_token = self.oidc_provider.fetch().await?;
148
149        let url = self.base_url.join("api/authorise")?;
150        tracing::debug!(url = %url, "federating OIDC token");
151
152        let resp = self
153            .http_client
154            .post(url)
155            .json(&OidcAuthoriseRequest {
156                oidc_token: oidc_token.as_str(),
157                workspace_id: self.workspace_id.as_str(),
158            })
159            .send()
160            .await?;
161
162        if !resp.status().is_success() {
163            let status = resp.status();
164            let body = resp.text().await.unwrap_or_default();
165            tracing::debug!(%status, %body, "OIDC federation failed");
166            return Err(AuthError::Server(format!("{status}: {body}")));
167        }
168
169        let auth_resp: AuthoriseResponse = resp.json().await?;
170        let now = SystemTime::now()
171            .duration_since(UNIX_EPOCH)
172            .unwrap_or_default()
173            .as_secs();
174
175        Ok(Token {
176            access_token: auth_resp.access_token,
177            token_type: "Bearer".to_string(),
178            expires_at: now + auth_resp.expiry,
179            refresh_token: None,
180            region: None,
181            client_id: None,
182            device_instance_id: None,
183        })
184    }
185}
186
187#[derive(serde::Serialize)]
188#[serde(rename_all = "camelCase")]
189struct OidcAuthoriseRequest<'a> {
190    oidc_token: &'a str,
191    workspace_id: &'a str,
192}
193
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::time::{SystemTime, UNIX_EPOCH};

    use mocktail::prelude::*;

    use super::*;
    use crate::auto_refresh::{AutoRefresh, AutoRefreshError};
    use crate::TokenStore;

    const WORKSPACE_ID: &str = "ZVATKW3VHMFG27DY";
    const AUTHORISE_PATH: &str = "/api/authorise";

    fn workspace_id() -> WorkspaceId {
        WORKSPACE_ID.parse().unwrap()
    }

    /// Body of a successful `/api/authorise` response.
    fn auth_response_json(access: &str, expiry: u64) -> serde_json::Value {
        serde_json::json!({ "accessToken": access, "expiry": expiry })
    }

    /// Mocks that answer `POST /api/authorise` with a successful token response.
    fn authorise_succeeds(access: &str, expiry: u64) -> MockSet {
        let body = auth_response_json(access, expiry);
        let mut mocks = MockSet::new();
        mocks.mock(move |when, then| {
            when.post().path(AUTHORISE_PATH);
            then.json(body.clone());
        });
        mocks
    }

    /// Mocks that answer `POST /api/authorise` with a 500 carrying `error`.
    fn authorise_fails(error: &str) -> MockSet {
        let body = serde_json::json!({ "error": error });
        let mut mocks = MockSet::new();
        mocks.mock(move |when, then| {
            when.post().path(AUTHORISE_PATH);
            then.internal_server_error().json(body.clone());
        });
        mocks
    }

    async fn start_server(mocks: MockSet) -> MockServer {
        let server = MockServer::new_http("oidc-refresher-test").with_mocks(mocks);
        server.start().await.unwrap();
        server
    }

    /// Test-double [`OidcProvider`] that counts its invocations and yields a
    /// distinct JWT per call (`jwt-0`, `jwt-1`, …).
    fn counting_provider() -> (Arc<AtomicUsize>, impl OidcProvider) {
        let counter = Arc::new(AtomicUsize::new(0));
        let provider = OidcProviderFn::new({
            let counter = Arc::clone(&counter);
            move || {
                let counter = Arc::clone(&counter);
                async move {
                    let call = counter.fetch_add(1, Ordering::SeqCst);
                    Ok(SecretToken::new(format!("jwt-{call}")))
                }
            }
        });
        (counter, provider)
    }

    /// An [`OidcRefresher`] pointed at the mock server.
    fn refresher_for<P>(server: &MockServer, provider: P) -> OidcRefresher<P> {
        OidcRefresher::new(provider, workspace_id(), server.url(""))
    }

    fn make_strategy<P: OidcProvider>(
        server: &MockServer,
        provider: P,
    ) -> AutoRefresh<OidcRefresher<P>> {
        AutoRefresh::with_store(refresher_for(server, provider), crate::NoStore)
    }

    fn make_token(access: &str, expires_in_secs: u64) -> Token {
        let issued_at = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs();
        Token {
            access_token: SecretToken::new(access),
            token_type: "Bearer".to_string(),
            expires_at: issued_at + expires_in_secs,
            refresh_token: None,
            region: None,
            client_id: None,
            device_instance_id: None,
        }
    }

    #[tokio::test]
    async fn test_initial_federation() {
        let server = start_server(authorise_succeeds("cts-token", 3600)).await;
        let (calls, provider) = counting_provider();
        let strategy = make_strategy(&server, provider);

        let token = strategy.get_token().await.unwrap();

        assert_eq!(token.as_str(), "cts-token");
        assert_eq!(
            calls.load(Ordering::SeqCst),
            1,
            "initial federation should invoke the OIDC provider once"
        );
    }

    #[test]
    fn test_request_serialization() {
        let serialised = serde_json::to_value(OidcAuthoriseRequest {
            oidc_token: "the-jwt",
            workspace_id: WORKSPACE_ID,
        })
        .unwrap();
        let expected = serde_json::json!({ "oidcToken": "the-jwt", "workspaceId": WORKSPACE_ID });
        assert_eq!(
            serialised, expected,
            "request body should carry exactly the OIDC token and workspace ID"
        );
    }

    #[tokio::test]
    async fn test_caches_token_after_initial_federation() {
        let server = start_server(authorise_succeeds("cts-token", 3600)).await;
        let (calls, provider) = counting_provider();
        let strategy = make_strategy(&server, provider);

        assert_eq!(strategy.get_token().await.unwrap().as_str(), "cts-token");

        // From here on, any further federation attempt hits a 500 and fails.
        server.mocks().clear();
        server.mocks().mock(|when, then| {
            when.post().path(AUTHORISE_PATH);
            then.internal_server_error()
                .json(serde_json::json!({"error": "should not be called"}));
        });

        assert_eq!(strategy.get_token().await.unwrap().as_str(), "cts-token");
        assert_eq!(
            calls.load(Ordering::SeqCst),
            1,
            "cached token should be returned without re-federating"
        );
    }

    #[tokio::test]
    async fn test_re_federates_on_expiry() {
        let server = start_server(authorise_succeeds("re-federated-token", 3600)).await;

        // Seed the store with a token whose lifetime is already used up.
        let store = Arc::new(crate::InMemoryTokenStore::new());
        store.save(&make_token("stale-cts-token", 0)).await;

        let (calls, provider) = counting_provider();
        let strategy = AutoRefresh::with_store(refresher_for(&server, provider), Arc::clone(&store));

        let token = strategy.get_token().await.unwrap();
        assert_eq!(
            token.as_str(),
            "re-federated-token",
            "expired cached token should trigger re-federation"
        );
        assert_eq!(
            calls.load(Ordering::SeqCst),
            1,
            "re-federation should invoke the OIDC provider for a current JWT"
        );
    }

    #[tokio::test]
    async fn test_oidc_provider_failure_propagates() {
        let server = start_server(authorise_succeeds("unreachable", 3600)).await;

        let failing = OidcProviderFn::new(|| async {
            Err::<SecretToken, _>(AuthError::Server("provider exploded".to_string()))
        });
        let strategy = make_strategy(&server, failing);

        let err = strategy.get_token().await.unwrap_err();
        assert!(
            matches!(err, AutoRefreshError::Auth(AuthError::Server(_))),
            "OIDC provider failure should surface as an auth error, got: {err:?}"
        );
    }

    #[tokio::test]
    async fn test_server_rejection_propagates() {
        let server = start_server(authorise_fails("workspace mismatch")).await;
        let (_calls, provider) = counting_provider();
        let strategy = make_strategy(&server, provider);

        let err = strategy.get_token().await.unwrap_err();
        assert!(
            matches!(err, AutoRefreshError::Auth(AuthError::Server(_))),
            "a 500 from /api/authorise should surface as a server error, got: {err:?}"
        );
    }

    #[tokio::test]
    async fn test_loads_token_from_store_on_cold_start_no_http() {
        let server = start_server(authorise_fails("should not be called")).await;

        let store = Arc::new(crate::InMemoryTokenStore::new());
        store.save(&make_token("from-store", 3600)).await;

        let (calls, provider) = counting_provider();
        let strategy = AutoRefresh::with_store(refresher_for(&server, provider), Arc::clone(&store));

        let token = strategy.get_token().await.unwrap();
        assert_eq!(token.as_str(), "from-store");
        assert_eq!(
            calls.load(Ordering::SeqCst),
            0,
            "a fresh cached token should be used without invoking the OIDC provider"
        );
    }

    #[tokio::test]
    async fn test_persists_token_to_store_after_federation() {
        let server = start_server(authorise_succeeds("freshly-federated", 3600)).await;

        let store = Arc::new(crate::InMemoryTokenStore::new());
        let (_calls, provider) = counting_provider();
        let strategy = AutoRefresh::with_store(refresher_for(&server, provider), Arc::clone(&store));

        let token = strategy.get_token().await.unwrap();
        assert_eq!(token.as_str(), "freshly-federated");

        let persisted = store
            .load()
            .await
            .expect("store should hold a token after federation");
        assert_eq!(persisted.access_token().as_str(), "freshly-federated");
    }
}