Skip to main content

stack_auth/
oidc_federation_strategy.rs

1use cts_common::{CtsServiceDiscovery, Region, ServiceDiscovery, WorkspaceId};
2
3use crate::auto_refresh::AutoRefresh;
4use crate::oidc_refresher::{OidcProvider, OidcRefresher};
5use crate::token_store::{NoStore, TokenStore};
6use crate::{ensure_trailing_slash, AuthError, AuthStrategy, ServiceToken};
7
8/// An [`AuthStrategy`] that federates a third-party OIDC JWT (Clerk, Supabase,
9/// Auth0, …) into a CipherStash CTS service token via `POST /api/authorise`.
10///
11/// Each call to [`get_token`](AuthStrategy::get_token) returns a cached CTS
12/// token until it expires. Because `/api/authorise` issues no CTS refresh
13/// token, renewal means *re-federating*: the strategy calls the
14/// [`OidcProvider`] again for a current third-party JWT and exchanges it for a
15/// fresh CTS token. Supply an `OidcProvider` that returns the live provider
16/// token each time (e.g. wrapping `clerk.session.getToken()`).
17///
18/// Every returned token is checked against the configured workspace — the
19/// same post-auth verification [`AccessKeyStrategy`](crate::AccessKeyStrategy)
20/// performs — so a token CTS minted for a different workspace (or one loaded
21/// from a poisoned shared cache) is never handed back. Verification can fail
22/// in two ways:
23///
24/// - [`AuthError::WorkspaceMismatch`] — the JWT decoded cleanly but its
25///   `workspace` claim doesn't match the configured workspace ID.
26/// - [`AuthError::InvalidToken`] — the JWT is malformed or missing the
27///   `workspace` claim entirely, so verification can't run.
28///
29/// When constructed via [`OidcFederationStrategyBuilder::with_token_store`], the strategy
30/// also persists tokens through an external [`TokenStore`] so short-lived
31/// instances (e.g. one per Edge Function request) can share a cache and skip
32/// re-federating on every cold start. The workspace check runs on cached and
33/// store-loaded tokens too, not just freshly federated ones.
34///
35/// # Example
36///
37/// ```no_run
38/// use stack_auth::{AuthError, OidcProviderFn, OidcFederationStrategy, SecretToken};
39/// use cts_common::{Region, WorkspaceId};
40///
41/// let region = Region::aws("ap-southeast-2").unwrap();
42/// let workspace_id: WorkspaceId = "ZVATKW3VHMFG27DY".parse().unwrap();
43/// let provider = OidcProviderFn::new(|| async {
44///     // Real consumers call into a provider SDK / FFI to fetch a live JWT.
45///     Ok::<_, AuthError>(SecretToken::new("header.payload.signature".to_string()))
46/// });
47/// let strategy = OidcFederationStrategy::new(region, workspace_id, provider).unwrap();
48/// ```
49pub struct OidcFederationStrategy<P, S = NoStore> {
50    inner: AutoRefresh<OidcRefresher<P>, S>,
51    expected_workspace: WorkspaceId,
52}
53
54impl<P: OidcProvider> OidcFederationStrategy<P> {
55    /// Create a new `OidcFederationStrategy` for the given region, workspace, and
56    /// OIDC provider.
57    ///
58    /// The auth endpoint is resolved automatically via service discovery.
59    pub fn new(
60        region: Region,
61        workspace_id: WorkspaceId,
62        oidc_provider: P,
63    ) -> Result<Self, AuthError> {
64        Self::builder(region, workspace_id, oidc_provider).build()
65    }
66
67    /// Return a builder for configuring an `OidcFederationStrategy` before construction.
68    pub fn builder(
69        region: Region,
70        workspace_id: WorkspaceId,
71        oidc_provider: P,
72    ) -> OidcFederationStrategyBuilder<P> {
73        OidcFederationStrategyBuilder {
74            region,
75            workspace_id,
76            oidc_provider,
77            base_url_override: None,
78            token_store: NoStore,
79        }
80    }
81}
82
83impl<P: OidcProvider, S: TokenStore> AuthStrategy for &OidcFederationStrategy<P, S> {
84    async fn get_token(self) -> Result<ServiceToken, AuthError> {
85        let token: ServiceToken = self.inner.get_token().await?;
86        let token_workspace = *token.workspace_id()?;
87        if token_workspace != self.expected_workspace {
88            return Err(AuthError::WorkspaceMismatch {
89                expected_workspace: self.expected_workspace,
90                token_workspace,
91            });
92        }
93        Ok(token)
94    }
95}
96
97/// Builder for [`OidcFederationStrategy`].
98///
99/// Created via [`OidcFederationStrategy::builder`].
100pub struct OidcFederationStrategyBuilder<P, S = NoStore> {
101    region: Region,
102    workspace_id: WorkspaceId,
103    oidc_provider: P,
104    base_url_override: Option<url::Url>,
105    token_store: S,
106}
107
108impl<P, S> OidcFederationStrategyBuilder<P, S> {
109    /// Override the base URL resolved by service discovery.
110    ///
111    /// Useful for pointing at a local or mock auth server during testing.
112    #[cfg(any(test, feature = "test-utils"))]
113    pub fn base_url(mut self, url: url::Url) -> Self {
114        self.base_url_override = Some(url);
115        self
116    }
117
118    /// Wire an external [`TokenStore`] into the strategy.
119    ///
120    /// On every call to [`get_token`](AuthStrategy::get_token), if no token is
121    /// cached in memory, the store is consulted before falling back to
122    /// re-federating. After every successful federation the new token is
123    /// written back to the store. Use this from short-lived strategy instances
124    /// (Edge Functions, Workers) to share a service-token cache across
125    /// processes — e.g. an HTTP-only cookie.
126    ///
127    /// Returns a new builder with the store type erased into the chain — see
128    /// [`InMemoryTokenStore`](crate::InMemoryTokenStore) and
129    /// [`TokenStoreFn`](crate::TokenStoreFn) for ready-made implementations.
130    pub fn with_token_store<T: TokenStore>(self, store: T) -> OidcFederationStrategyBuilder<P, T> {
131        OidcFederationStrategyBuilder {
132            region: self.region,
133            workspace_id: self.workspace_id,
134            oidc_provider: self.oidc_provider,
135            base_url_override: self.base_url_override,
136            token_store: store,
137        }
138    }
139}
140
141impl<P: OidcProvider, S: TokenStore> OidcFederationStrategyBuilder<P, S> {
142    /// Build the [`OidcFederationStrategy`].
143    ///
144    /// Resolves the base URL via service discovery unless overridden with
145    /// `base_url` (available when the `test-utils` feature is enabled).
146    pub fn build(self) -> Result<OidcFederationStrategy<P, S>, AuthError> {
147        let base_url = match self.base_url_override {
148            Some(url) => url,
149            None => crate::cts_base_url_from_env()?
150                .unwrap_or(CtsServiceDiscovery::endpoint(self.region)?),
151        };
152        let expected_workspace = self.workspace_id;
153        let refresher = OidcRefresher::new(
154            self.oidc_provider,
155            self.workspace_id,
156            ensure_trailing_slash(base_url),
157        );
158        Ok(OidcFederationStrategy {
159            inner: AutoRefresh::with_store(refresher, self.token_store),
160            expected_workspace,
161        })
162    }
163}
164
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use std::sync::Arc;
    use std::time::{SystemTime, UNIX_EPOCH};

    use cts_common::Region;
    use mocktail::prelude::*;

    use super::*;
    use crate::oidc_refresher::OidcProviderFn;
    use crate::{InMemoryTokenStore, SecretToken, Token, TokenStore};

    /// Current Unix time in whole seconds.
    fn now_secs() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system clock")
            .as_secs()
    }

    /// Mint a CTS-shaped JWT carrying the given `workspace` claim.
    ///
    /// The token is HS256-signed with a throwaway test secret
    /// (`Header::default()` + `EncodingKey::from_secret`). The strategy
    /// decodes claims without verifying the signature (it already holds the
    /// token), so the secret is irrelevant. Any well-formed JWT is enough to
    /// exercise workspace verification.
    fn jwt_with_workspace(workspace: &str) -> String {
        use jsonwebtoken::{encode, EncodingKey, Header};
        let now = now_secs();
        let claims = serde_json::json!({
            "iss": "https://cts.example.com/",
            "sub": "CS|test-user",
            "aud": "test-audience",
            "iat": now,
            "exp": now + 3600,
            "workspace": workspace,
            "scope": "",
        });
        encode(
            &Header::default(),
            &claims,
            &EncodingKey::from_secret(b"test-secret"),
        )
        .expect("JWT encode")
    }

    /// A mock CTS that federates any OIDC token into a CTS token carrying the
    /// given `workspace` claim.
    async fn start_mock_server_returning_jwt(workspace: &str) -> MockServer {
        let mut mocks = MockSet::new();
        let jwt = jwt_with_workspace(workspace);
        mocks.mock(move |when, then| {
            when.post().path("/api/authorise");
            then.json(serde_json::json!({ "accessToken": jwt, "expiry": 3600 }));
        });
        let server =
            MockServer::new_http("oidc-federation-strategy-workspace-test").with_mocks(mocks);
        server.start().await.expect("mock server start");
        server
    }

    fn test_region() -> Region {
        Region::aws("ap-southeast-2").expect("region parses")
    }

    fn provider() -> OidcProviderFn<impl Fn() -> std::future::Ready<Result<SecretToken, AuthError>>>
    {
        OidcProviderFn::new(|| {
            std::future::ready(Ok(SecretToken::new("header.payload.signature".to_string())))
        })
    }

    /// Happy path: the federated token's `workspace` claim matches the
    /// configured workspace, so `get_token()` returns the token cleanly.
    #[tokio::test]
    async fn returns_token_when_workspace_matches() {
        const WS: &str = "ZVATKW3VHMFG27DY";
        let server = start_mock_server_returning_jwt(WS).await;

        let strategy =
            OidcFederationStrategy::builder(test_region(), WS.parse().unwrap(), provider())
                .base_url(server.url(""))
                .build()
                .expect("builder");

        let token = (&strategy).get_token().await.expect("get_token");
        assert_eq!(
            token.workspace_id().expect("workspace_id").as_str(),
            WS,
            "happy-path token should carry the expected workspace",
        );
    }

    /// Mismatch: CTS federates the OIDC token into a CTS token for a
    /// *different* workspace than the strategy was configured for. This is the
    /// security-critical case, because the OIDC provider could be authenticated
    /// for a workspace the caller didn't intend. `get_token()` must return
    /// `WorkspaceMismatch`, not the token.
    #[tokio::test]
    async fn errors_when_token_workspace_differs() {
        const TOKEN_WS: &str = "AAAAAAAAAAAAAAAA";
        const EXPECTED_WS: &str = "ZVATKW3VHMFG27DY";
        let server = start_mock_server_returning_jwt(TOKEN_WS).await;

        let strategy = OidcFederationStrategy::builder(
            test_region(),
            EXPECTED_WS.parse().unwrap(),
            provider(),
        )
        .base_url(server.url(""))
        .build()
        .expect("builder");

        let err = (&strategy)
            .get_token()
            .await
            .expect_err("expected mismatch");
        match err {
            AuthError::WorkspaceMismatch {
                expected_workspace,
                token_workspace,
            } => {
                assert_eq!(expected_workspace.as_str(), EXPECTED_WS);
                assert_eq!(token_workspace.as_str(), TOKEN_WS);
            }
            other => panic!("expected WorkspaceMismatch, got {other:?}"),
        }
    }

    /// A malformed CTS token (not a JWT) can't be decoded, so verification
    /// can't run. `get_token()` surfaces `InvalidToken` rather than handing
    /// back an unverifiable token.
    #[tokio::test]
    async fn errors_with_invalid_token_when_jwt_malformed() {
        let mut mocks = MockSet::new();
        mocks.mock(|when, then| {
            when.post().path("/api/authorise");
            then.json(serde_json::json!({ "accessToken": "not-a-jwt", "expiry": 3600 }));
        });
        let server =
            MockServer::new_http("oidc-federation-strategy-malformed-test").with_mocks(mocks);
        server.start().await.expect("mock server start");

        let strategy = OidcFederationStrategy::builder(
            test_region(),
            "ZVATKW3VHMFG27DY".parse().unwrap(),
            provider(),
        )
        .base_url(server.url(""))
        .build()
        .expect("builder");

        let err = (&strategy)
            .get_token()
            .await
            .expect_err("expected invalid-token error");
        assert!(
            matches!(err, AuthError::InvalidToken(_)),
            "expected InvalidToken, got {err:?}",
        );
    }

    /// A pre-populated [`TokenStore`] returning a token for a *different*
    /// workspace must still be rejected by the strategy wrapper. This is the
    /// same poisoned-shared-cache interaction `AccessKeyStrategy` guards against.
    /// The mock returns 500 for `/api/authorise`, so if the strategy ever
    /// re-federated instead of using (and rejecting) the stored token, the
    /// error would not be `WorkspaceMismatch` and the assertion below would fail.
    #[tokio::test]
    async fn rejects_stored_token_for_different_workspace() {
        const TOKEN_WS: &str = "AAAAAAAAAAAAAAAA";
        const EXPECTED_WS: &str = "ZVATKW3VHMFG27DY";

        let mut mocks = MockSet::new();
        mocks.mock(|when, then| {
            when.post().path("/api/authorise");
            then.internal_server_error()
                .json(serde_json::json!({"error": "store must satisfy the request"}));
        });
        let server =
            MockServer::new_http("oidc-federation-strategy-store-mismatch-test").with_mocks(mocks);
        server.start().await.expect("mock server start");

        let stored = Token {
            access_token: SecretToken::new(jwt_with_workspace(TOKEN_WS)),
            token_type: "Bearer".to_string(),
            expires_at: now_secs() + 3600,
            refresh_token: None,
            region: None,
            client_id: None,
            device_instance_id: None,
        };
        let store = Arc::new(InMemoryTokenStore::new());
        store.save(&stored).await;

        let strategy = OidcFederationStrategy::builder(
            test_region(),
            EXPECTED_WS.parse().unwrap(),
            provider(),
        )
        .base_url(server.url(""))
        .with_token_store(Arc::clone(&store))
        .build()
        .expect("builder");

        let err = (&strategy)
            .get_token()
            .await
            .expect_err("expected mismatch from stored token");
        assert!(
            matches!(err, AuthError::WorkspaceMismatch { .. }),
            "expected WorkspaceMismatch, got {err:?}",
        );
    }

    /// Regression guard: the workspace check runs on *every* `get_token()`
    /// call, not only the one that triggers initial federation. A future
    /// optimisation that cached the "verified" verdict would let a mismatched
    /// token slide through on the second call.
    #[tokio::test]
    async fn errors_on_each_subsequent_get_token_call() {
        const TOKEN_WS: &str = "AAAAAAAAAAAAAAAA";
        const EXPECTED_WS: &str = "ZVATKW3VHMFG27DY";
        let server = start_mock_server_returning_jwt(TOKEN_WS).await;

        let strategy = OidcFederationStrategy::builder(
            test_region(),
            EXPECTED_WS.parse().unwrap(),
            provider(),
        )
        .base_url(server.url(""))
        .build()
        .expect("builder");

        for call in 1..=2 {
            let err = match (&strategy).get_token().await {
                Ok(_) => panic!("call {call}: expected Err, got Ok"),
                Err(e) => e,
            };
            assert!(
                matches!(err, AuthError::WorkspaceMismatch { .. }),
                "call {call}: expected WorkspaceMismatch, got {err:?}",
            );
        }
    }
}