Skip to main content

stack_auth/
token.rs

1use web_time::{SystemTime, UNIX_EPOCH};
2
3use cts_common::claims::Claims;
4use cts_common::{Crn, Region, WorkspaceId};
5use url::Url;
6
7use crate::{http_client, AuthError, SecretToken};
8
9#[cfg(not(target_arch = "wasm32"))]
10impl stack_profile::ProfileData for Token {
11    const FILENAME: &'static str = "auth.json";
12    const MODE: Option<u32> = Some(0o600);
13}
14
15/// How many seconds before expiry [`Token::is_expired`] returns `true`.
16///
17/// This leeway triggers preemptive refresh well before the token becomes
18/// unusable, giving the HTTP refresh call time to complete while concurrent
19/// callers can still use the current token.
20const EXPIRY_LEEWAY_SECS: u64 = 90;
21
22/// An access token returned by a successful authentication flow.
23///
24/// The token contains a [`SecretToken`] (the bearer credential), a token type
25/// (typically `"Bearer"`), and an absolute expiry timestamp.
26#[derive(Debug, serde::Serialize, serde::Deserialize)]
27pub struct Token {
28    pub(crate) access_token: SecretToken,
29    #[serde(default, skip_serializing_if = "Option::is_none")]
30    pub(crate) refresh_token: Option<SecretToken>,
31    pub(crate) token_type: String,
32    pub(crate) expires_at: u64,
33    #[serde(default, skip_serializing_if = "Option::is_none")]
34    pub(crate) region: Option<String>,
35    #[serde(default, skip_serializing_if = "Option::is_none")]
36    pub(crate) client_id: Option<String>,
37    #[serde(default, skip_serializing_if = "Option::is_none")]
38    pub(crate) device_instance_id: Option<String>,
39}
40
41impl Token {
42    /// Returns a reference to the access token credential.
43    ///
44    /// The returned [`SecretToken`] is opaque — its [`Debug`] output is masked.
45    /// Pass it to API clients that need the raw bearer token.
46    pub fn access_token(&self) -> &SecretToken {
47        &self.access_token
48    }
49
50    /// The token type (e.g. `"Bearer"`).
51    pub fn token_type(&self) -> &str {
52        &self.token_type
53    }
54
55    /// The absolute epoch timestamp when the token expires.
56    pub fn expires_at(&self) -> u64 {
57        self.expires_at
58    }
59
60    /// How many seconds until the token expires (computed from the current time).
61    pub fn expires_in(&self) -> u64 {
62        let now = SystemTime::now()
63            .duration_since(UNIX_EPOCH)
64            .unwrap_or_default()
65            .as_secs();
66        self.expires_at.saturating_sub(now)
67    }
68
69    /// Returns `true` if the token has expired (with 90 seconds of leeway).
70    ///
71    /// The 90-second leeway triggers preemptive refresh well before the token
72    /// becomes unusable, giving the HTTP refresh call plenty of time to complete
73    /// while the current token is still valid for concurrent callers.
74    ///
75    /// For checking whether the token is still usable as a bearer credential,
76    /// use [`is_usable`](Self::is_usable) instead.
77    pub fn is_expired(&self) -> bool {
78        let now = SystemTime::now()
79            .duration_since(UNIX_EPOCH)
80            .unwrap_or_default()
81            .as_secs();
82        now + EXPIRY_LEEWAY_SECS >= self.expires_at
83    }
84
85    /// Returns `true` if the token is still usable (before the actual expiry timestamp).
86    ///
87    /// Unlike [`is_expired`](Self::is_expired) which includes 90s leeway for preemptive
88    /// refresh, this only returns `false` when the token has genuinely expired.
89    pub fn is_usable(&self) -> bool {
90        let now = SystemTime::now()
91            .duration_since(UNIX_EPOCH)
92            .unwrap_or_default()
93            .as_secs();
94        now < self.expires_at
95    }
96
97    /// Returns a reference to the refresh token, if one was provided.
98    pub fn refresh_token(&self) -> Option<&SecretToken> {
99        self.refresh_token.as_ref()
100    }
101
102    /// Takes the refresh token out, leaving `None` in its place.
103    pub fn take_refresh_token(&mut self) -> Option<SecretToken> {
104        self.refresh_token.take()
105    }
106
107    /// Returns the stored region identifier, if any.
108    pub fn region(&self) -> Option<&str> {
109        self.region.as_deref()
110    }
111
112    /// Returns the stored client ID, if any.
113    pub fn client_id(&self) -> Option<&str> {
114        self.client_id.as_deref()
115    }
116
117    /// Set the region identifier on this token.
118    pub(crate) fn set_region(&mut self, region: impl Into<String>) {
119        self.region = Some(region.into());
120    }
121
122    /// Set the client ID on this token.
123    pub(crate) fn set_client_id(&mut self, client_id: impl Into<String>) {
124        self.client_id = Some(client_id.into());
125    }
126
127    /// Returns the stored device instance ID, if any.
128    pub fn device_instance_id(&self) -> Option<&str> {
129        self.device_instance_id.as_deref()
130    }
131
132    /// Set the device instance ID on this token.
133    pub(crate) fn set_device_instance_id(&mut self, id: impl Into<String>) {
134        self.device_instance_id = Some(id.into());
135    }
136
137    /// Returns the workspace ID from the JWT claims.
138    ///
139    /// The access token is decoded (without signature verification) to extract
140    /// the `workspace` claim.
141    pub fn workspace_id(&self) -> Result<WorkspaceId, AuthError> {
142        self.decode_claims().map(|c| c.workspace)
143    }
144
145    /// Returns the workspace CRN derived from the token's region and workspace ID.
146    ///
147    /// The region is set during the device code flow, and the workspace ID is
148    /// extracted from the JWT `workspace` claim.
149    pub fn workspace_crn(&self) -> Result<Crn, AuthError> {
150        let workspace_id = self.workspace_id()?;
151        let region: Region = self
152            .region()
153            .ok_or(AuthError::NotAuthenticated)?
154            .parse()
155            .map_err(|e: cts_common::RegionError| AuthError::Server(e.to_string()))?;
156        Ok(Crn::new(region, workspace_id))
157    }
158
159    /// Returns the issuer URL from the JWT claims.
160    ///
161    /// The `iss` claim in CipherStash tokens is the CTS host URL for the
162    /// workspace, so this can be used directly as the CTS base URL.
163    pub fn issuer(&self) -> Result<Url, AuthError> {
164        let claims = self.decode_claims()?;
165        claims.iss.parse().map_err(AuthError::from)
166    }
167
168    /// Decode the JWT payload into [`Claims`] without verifying the signature.
169    ///
170    /// This is safe because we already possess the token — we just need to read
171    /// the claims it contains.
172    #[cfg(not(target_arch = "wasm32"))]
173    fn decode_claims(&self) -> Result<Claims, AuthError> {
174        use jsonwebtoken::{decode, decode_header, DecodingKey, Validation};
175        use std::collections::HashSet;
176
177        let token_str = self.access_token.as_str();
178        let header = decode_header(token_str)
179            .map_err(|e| AuthError::InvalidToken(format!("invalid JWT header: {e}")))?;
180
181        let dummy_key = DecodingKey::from_secret(&[]);
182        let mut validation = Validation::new(header.alg);
183        validation.validate_exp = false;
184        validation.validate_aud = false;
185        validation.required_spec_claims = HashSet::new();
186        validation.insecure_disable_signature_validation();
187
188        decode(token_str, &dummy_key, &validation)
189            .map(|data| data.claims)
190            .map_err(|e| AuthError::InvalidToken(format!("failed to decode JWT claims: {e}")))
191    }
192
193    /// Wasm32 path: decode the JWT payload by splitting + base64 + JSON. We
194    /// don't need the cryptographic backing of `jsonwebtoken` (which pulls
195    /// `ring`) because we only ever read claims from a token we already hold;
196    /// signature validation is `insecure_disable_signature_validation()` on
197    /// native too.
198    #[cfg(target_arch = "wasm32")]
199    fn decode_claims(&self) -> Result<Claims, AuthError> {
200        crate::decode_jwt_payload_wasm(self.access_token.as_str())
201    }
202
203    /// Exchange a refresh token for a new [`Token`] via the `/oauth/token`
204    /// endpoint.
205    ///
206    /// This is a static constructor — it takes a bare [`SecretToken`] (the
207    /// refresh token) rather than operating on an existing `Token`. This
208    /// allows callers to manage the refresh token lifecycle independently
209    /// (e.g. taking it out of a cached token for cascade prevention and
210    /// restoring it on failure).
211    ///
212    /// # Errors
213    ///
214    /// - [`AuthError::InvalidGrant`] — the refresh token was revoked or expired.
215    /// - [`AuthError::InvalidClient`] — the client ID is not recognized.
216    /// - [`AuthError::Request`] — a network error occurred.
217    pub async fn refresh(
218        refresh_token: &SecretToken,
219        base_url: &Url,
220        client_id: &str,
221        device_instance_id: Option<&str>,
222    ) -> Result<Token, AuthError> {
223        let token_url = base_url.join("oauth/token")?;
224
225        tracing::debug!(url = %token_url, "refreshing token");
226
227        let resp = http_client()
228            .post(token_url)
229            .form(&RefreshRequest {
230                grant_type: "refresh_token",
231                client_id,
232                refresh_token: refresh_token.as_str(),
233                device_instance_id,
234            })
235            .send()
236            .await?;
237
238        if !resp.status().is_success() {
239            let err: RefreshErrorResponse = resp.json().await?;
240            tracing::debug!(error = %err.error, "token refresh failed");
241            return Err(match err.error.as_str() {
242                "invalid_grant" => AuthError::InvalidGrant,
243                "invalid_client" => AuthError::InvalidClient,
244                "access_denied" => AuthError::AccessDenied,
245                _ => AuthError::Server(err.error_description),
246            });
247        }
248
249        let token_resp: RefreshResponse = resp.json().await?;
250        let now = SystemTime::now()
251            .duration_since(UNIX_EPOCH)
252            .unwrap_or_default()
253            .as_secs();
254
255        Ok(Token {
256            access_token: token_resp.access_token,
257            token_type: token_resp.token_type,
258            expires_at: now + token_resp.expires_in,
259            refresh_token: token_resp.refresh_token,
260            region: None,
261            client_id: None,
262            // TODO(CIP-2793): The server should include device_instance_id in the
263            // refresh response. Until then, callers (e.g. OAuthRefresher) must
264            // re-attach it manually after refresh.
265            device_instance_id: None,
266        })
267    }
268}
269
270#[derive(serde::Serialize)]
271struct RefreshRequest<'a> {
272    grant_type: &'a str,
273    client_id: &'a str,
274    refresh_token: &'a str,
275    #[serde(skip_serializing_if = "Option::is_none")]
276    device_instance_id: Option<&'a str>,
277}
278
279#[derive(serde::Deserialize)]
280struct RefreshResponse {
281    access_token: SecretToken,
282    token_type: String,
283    expires_in: u64,
284    #[serde(default)]
285    refresh_token: Option<SecretToken>,
286}
287
288#[derive(serde::Deserialize)]
289struct RefreshErrorResponse {
290    error: String,
291    #[serde(default)]
292    error_description: String,
293}
294
#[cfg(test)]
mod tests {
    use super::*;
    use crate::AuthError;
    use mocktail::prelude::*;

    /// Current wall-clock time in whole seconds since the Unix epoch.
    fn unix_now() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    /// Build an opaque (non-JWT) token expiring `expires_in` seconds from now,
    /// optionally carrying a refresh token.
    fn make_token(expires_in: u64, refresh: bool) -> Token {
        Token {
            access_token: SecretToken::new("test-access-token"),
            token_type: "Bearer".to_string(),
            expires_at: unix_now() + expires_in,
            refresh_token: refresh.then(|| SecretToken::new("test-refresh-token")),
            region: None,
            client_id: None,
            device_instance_id: None,
        }
    }

    fn refresh_response_json() -> serde_json::Value {
        serde_json::json!({
            "access_token": "new-access-token",
            "token_type": "Bearer",
            "expires_in": 3600,
            "refresh_token": "new-refresh-token"
        })
    }

    fn error_json(error: &str) -> serde_json::Value {
        serde_json::json!({
            "error": error,
            "error_description": format!("{error} occurred")
        })
    }

    async fn start_server(mocks: MockSet) -> MockServer {
        let server = MockServer::new_http("token-refresh-test").with_mocks(mocks);
        server.start().await.unwrap();
        server
    }

    #[test]
    fn test_secret_token_debug_does_not_leak() {
        let token = SecretToken::new("super_secret_value");
        let debug = format!("{:?}", token);
        assert!(
            !debug.contains("super_secret_value"),
            "SecretToken Debug should not contain the secret, got: {debug}"
        );
    }

    // ---- expiry / accessor tests ----

    #[test]
    fn test_fresh_token_is_not_expired_and_is_usable() {
        let token = make_token(3600, false);
        assert!(!token.is_expired());
        assert!(token.is_usable());
    }

    #[test]
    fn test_token_inside_leeway_is_expired_but_still_usable() {
        // Half the leeway remaining: refresh should trigger, but the bearer
        // credential itself is still valid.
        let token = make_token(EXPIRY_LEEWAY_SECS / 2, false);
        assert!(token.is_expired());
        assert!(token.is_usable());
    }

    #[test]
    fn test_past_expiry_is_expired_and_unusable() {
        let mut token = make_token(0, false);
        token.expires_at = 0;
        assert!(token.is_expired());
        assert!(!token.is_usable());
        assert_eq!(token.expires_in(), 0);
    }

    #[test]
    fn test_take_refresh_token_leaves_none() {
        let mut token = make_token(3600, true);
        let taken = token
            .take_refresh_token()
            .expect("refresh token should be present");
        assert_eq!(taken.as_str(), "test-refresh-token");
        assert!(token.refresh_token().is_none());
    }

    #[test]
    fn test_serialize_omits_absent_optional_fields() {
        let json = serde_json::to_value(make_token(3600, false)).unwrap();
        for field in ["refresh_token", "region", "client_id", "device_instance_id"] {
            assert!(
                json.get(field).is_none(),
                "{field} should be omitted when None, got: {json}"
            );
        }
        assert_eq!(json["token_type"], "Bearer");
    }

    // ---- refresh() tests ----

    #[tokio::test]
    async fn test_refresh_success() {
        let mut mocks = MockSet::new();
        mocks.mock(|when, then| {
            when.post().path("/oauth/token");
            then.json(refresh_response_json());
        });
        let server = start_server(mocks).await;
        let base_url = server.url("");

        let refresh_token = SecretToken::new("test-refresh-token");
        let refreshed = Token::refresh(&refresh_token, &base_url, "cli", None)
            .await
            .unwrap();

        assert_eq!(refreshed.access_token().as_str(), "new-access-token");
        assert_eq!(refreshed.token_type(), "Bearer");
        assert_eq!(
            refreshed.refresh_token().unwrap().as_str(),
            "new-refresh-token"
        );
        assert!(!refreshed.is_expired());
        assert!((3598..=3600).contains(&refreshed.expires_in()));
    }

    #[tokio::test]
    async fn test_refresh_invalid_grant() {
        let mut mocks = MockSet::new();
        mocks.mock(|when, then| {
            when.post().path("/oauth/token");
            then.bad_request().json(error_json("invalid_grant"));
        });
        let server = start_server(mocks).await;
        let base_url = server.url("");

        let refresh_token = SecretToken::new("test-refresh-token");
        let err = Token::refresh(&refresh_token, &base_url, "cli", None)
            .await
            .unwrap_err();

        assert!(matches!(err, AuthError::InvalidGrant));
    }

    #[tokio::test]
    async fn test_refresh_invalid_client() {
        let mut mocks = MockSet::new();
        mocks.mock(|when, then| {
            when.post().path("/oauth/token");
            then.bad_request().json(error_json("invalid_client"));
        });
        let server = start_server(mocks).await;
        let base_url = server.url("");

        let refresh_token = SecretToken::new("test-refresh-token");
        let err = Token::refresh(&refresh_token, &base_url, "cli", None)
            .await
            .unwrap_err();

        assert!(matches!(err, AuthError::InvalidClient));
    }

    #[tokio::test]
    async fn test_refresh_access_denied() {
        let mut mocks = MockSet::new();
        mocks.mock(|when, then| {
            when.post().path("/oauth/token");
            then.bad_request().json(error_json("access_denied"));
        });
        let server = start_server(mocks).await;
        let base_url = server.url("");

        let refresh_token = SecretToken::new("test-refresh-token");
        let err = Token::refresh(&refresh_token, &base_url, "cli", None)
            .await
            .unwrap_err();

        assert!(matches!(err, AuthError::AccessDenied));
    }

    #[tokio::test]
    async fn test_refresh_unknown_error() {
        let mut mocks = MockSet::new();
        mocks.mock(|when, then| {
            when.post().path("/oauth/token");
            then.bad_request().json(error_json("something_unexpected"));
        });
        let server = start_server(mocks).await;
        let base_url = server.url("");

        let refresh_token = SecretToken::new("test-refresh-token");
        let err = Token::refresh(&refresh_token, &base_url, "cli", None)
            .await
            .unwrap_err();

        assert!(matches!(&err, AuthError::Server(desc) if desc == "something_unexpected occurred"));
    }

    #[tokio::test]
    async fn test_refresh_response_without_new_refresh_token() {
        let mut mocks = MockSet::new();
        mocks.mock(|when, then| {
            when.post().path("/oauth/token");
            then.json(serde_json::json!({
                "access_token": "new-access-token",
                "token_type": "Bearer",
                "expires_in": 3600
            }));
        });
        let server = start_server(mocks).await;
        let base_url = server.url("");

        let refresh_token = SecretToken::new("test-refresh-token");
        let refreshed = Token::refresh(&refresh_token, &base_url, "cli", None)
            .await
            .unwrap();

        assert_eq!(refreshed.access_token().as_str(), "new-access-token");
        assert!(refreshed.refresh_token().is_none());
    }

    // Purely synchronous: no runtime needed.
    #[test]
    fn test_refresh_debug_does_not_leak_tokens() {
        let token = make_token(3600, true);
        let debug = format!("{:?}", token);
        assert!(
            !debug.contains("test-access-token"),
            "Debug output should not contain access token, got: {debug}"
        );
        assert!(
            !debug.contains("test-refresh-token"),
            "Debug output should not contain refresh token, got: {debug}"
        );
    }

    // ---- decode_claims / workspace_id / issuer tests ----

    /// Build a Token whose access_token is a real JWT carrying the given claims.
    ///
    /// The JWT is HS256-signed with a throwaway secret. `decode_claims` never
    /// verifies signatures, so the key is irrelevant to these tests.
    fn make_jwt_token(claims_json: serde_json::Value) -> Token {
        use jsonwebtoken::{encode, EncodingKey, Header};
        let jwt = encode(
            &Header::default(),
            &claims_json,
            &EncodingKey::from_secret(b"test-secret"),
        )
        .expect("failed to encode JWT");

        Token {
            access_token: SecretToken::new(jwt),
            token_type: "Bearer".to_string(),
            expires_at: unix_now() + 3600,
            refresh_token: None,
            region: None,
            client_id: None,
            device_instance_id: None,
        }
    }

    fn valid_claims_json() -> serde_json::Value {
        serde_json::json!({
            "workspace": "7366ITCXSAPCH5TN",
            "iss": "https://cts.example.com",
            "sub": "user-123",
            "aud": "https://cts.example.com",
            "iat": 1700000000u64,
            "exp": 1700003600u64,
            "scope": "dataset:create"
        })
    }

    #[test]
    fn test_workspace_id_extracts_from_jwt() {
        let token = make_jwt_token(valid_claims_json());
        let ws = token.workspace_id().expect("should extract workspace ID");
        assert_eq!(ws.to_string(), "7366ITCXSAPCH5TN");
    }

    #[test]
    fn test_issuer_extracts_url_from_jwt() {
        let token = make_jwt_token(valid_claims_json());
        let issuer = token.issuer().expect("should extract issuer");
        assert_eq!(issuer.as_str(), "https://cts.example.com/");
    }

    #[test]
    fn test_workspace_id_fails_on_invalid_jwt() {
        let token = Token {
            access_token: SecretToken::new("not-a-jwt"),
            token_type: "Bearer".to_string(),
            expires_at: 0,
            refresh_token: None,
            region: None,
            client_id: None,
            device_instance_id: None,
        };
        let err = token.workspace_id().unwrap_err();
        assert!(matches!(err, AuthError::InvalidToken(_)));
    }

    #[test]
    fn test_issuer_fails_on_missing_claims() {
        let token = make_jwt_token(serde_json::json!({"sub": "user-123"}));
        let err = token.issuer().unwrap_err();
        assert!(matches!(err, AuthError::InvalidToken(_)));
    }

    #[test]
    fn test_workspace_crn_derives_from_region_and_workspace() {
        let mut token = make_jwt_token(valid_claims_json());
        token.set_region("ap-southeast-2.aws");
        let crn = token.workspace_crn().expect("should derive workspace CRN");
        assert_eq!(crn.to_string(), "crn:ap-southeast-2.aws:7366ITCXSAPCH5TN");
    }

    #[test]
    fn test_workspace_crn_fails_without_region() {
        let token = make_jwt_token(valid_claims_json());
        let err = token.workspace_crn().unwrap_err();
        assert!(matches!(err, AuthError::NotAuthenticated));
    }

    #[test]
    fn test_workspace_crn_fails_with_invalid_region() {
        let mut token = make_jwt_token(valid_claims_json());
        token.set_region("invalid-region");
        let err = token.workspace_crn().unwrap_err();
        assert!(matches!(err, AuthError::Server(_)));
    }
}
585        token.set_region("invalid-region");
586        let err = token.workspace_crn().unwrap_err();
587        assert!(matches!(err, AuthError::Server(_)));
588    }
589}