1use web_time::{SystemTime, UNIX_EPOCH};
2
3use cts_common::claims::Claims;
4use cts_common::{Crn, Region, WorkspaceId};
5use url::Url;
6
7use crate::{http_client, AuthError, SecretToken};
8
// Lets `stack_profile` persist a `Token` as the `auth.json` profile file.
// Compiled out on wasm32 targets.
#[cfg(not(target_arch = "wasm32"))]
impl stack_profile::ProfileData for Token {
    /// File name the token is stored under inside the profile.
    const FILENAME: &'static str = "auth.json";
    /// Requested permission bits for the file. 0o600 is owner read/write only,
    /// which matters because the file contains the access and refresh tokens.
    /// NOTE(review): presumably applied by `stack_profile` when it writes the
    /// file — confirm there.
    const MODE: Option<u32> = Some(0o600);
}
14
/// Safety margin, in seconds, used by `Token::is_expired`: a token is reported
/// as expired once the current time is within this many seconds of its
/// `expires_at`, i.e. before it has actually expired.
const EXPIRY_LEEWAY_SECS: u64 = 90;
21
/// An OAuth access token together with the metadata needed to use and refresh it.
///
/// On non-wasm targets this is the content of the `auth.json` profile file, so
/// the serialized field names are that file's format. Optional fields are left
/// out of the JSON when `None` and default to `None` when missing on load.
#[derive(Debug, serde::Serialize, serde::Deserialize)]
pub struct Token {
    /// The access token. It is decoded as a JWT to read the `workspace` and
    /// `iss` claims (see `Token::workspace_id` / `Token::issuer`).
    pub(crate) access_token: SecretToken,
    /// Refresh token used by `Token::refresh`, if the server issued one.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub(crate) refresh_token: Option<SecretToken>,
    /// Token type string from the token endpoint response (e.g. `"Bearer"`).
    pub(crate) token_type: String,
    /// Absolute expiry time, in whole seconds since the Unix epoch.
    pub(crate) expires_at: u64,
    /// Region string; parsed as a `cts_common::Region` by `Token::workspace_crn`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub(crate) region: Option<String>,
    /// OAuth client ID associated with this token, if set.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub(crate) client_id: Option<String>,
    /// Device instance identifier, if set.
    /// NOTE(review): presumably the value passed as `device_instance_id` to
    /// `Token::refresh` — confirm at the call site.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub(crate) device_instance_id: Option<String>,
}
40
41impl Token {
42 pub fn access_token(&self) -> &SecretToken {
47 &self.access_token
48 }
49
50 pub fn token_type(&self) -> &str {
52 &self.token_type
53 }
54
55 pub fn expires_at(&self) -> u64 {
57 self.expires_at
58 }
59
60 pub fn expires_in(&self) -> u64 {
62 let now = SystemTime::now()
63 .duration_since(UNIX_EPOCH)
64 .unwrap_or_default()
65 .as_secs();
66 self.expires_at.saturating_sub(now)
67 }
68
69 pub fn is_expired(&self) -> bool {
78 let now = SystemTime::now()
79 .duration_since(UNIX_EPOCH)
80 .unwrap_or_default()
81 .as_secs();
82 now + EXPIRY_LEEWAY_SECS >= self.expires_at
83 }
84
85 pub fn is_usable(&self) -> bool {
90 let now = SystemTime::now()
91 .duration_since(UNIX_EPOCH)
92 .unwrap_or_default()
93 .as_secs();
94 now < self.expires_at
95 }
96
97 pub fn refresh_token(&self) -> Option<&SecretToken> {
99 self.refresh_token.as_ref()
100 }
101
102 pub fn take_refresh_token(&mut self) -> Option<SecretToken> {
104 self.refresh_token.take()
105 }
106
107 pub fn region(&self) -> Option<&str> {
109 self.region.as_deref()
110 }
111
112 pub fn client_id(&self) -> Option<&str> {
114 self.client_id.as_deref()
115 }
116
117 pub(crate) fn set_region(&mut self, region: impl Into<String>) {
119 self.region = Some(region.into());
120 }
121
122 pub(crate) fn set_client_id(&mut self, client_id: impl Into<String>) {
124 self.client_id = Some(client_id.into());
125 }
126
127 pub fn device_instance_id(&self) -> Option<&str> {
129 self.device_instance_id.as_deref()
130 }
131
132 pub(crate) fn set_device_instance_id(&mut self, id: impl Into<String>) {
134 self.device_instance_id = Some(id.into());
135 }
136
137 pub fn workspace_id(&self) -> Result<WorkspaceId, AuthError> {
142 self.decode_claims().map(|c| c.workspace)
143 }
144
145 pub fn workspace_crn(&self) -> Result<Crn, AuthError> {
150 let workspace_id = self.workspace_id()?;
151 let region: Region = self
152 .region()
153 .ok_or(AuthError::NotAuthenticated)?
154 .parse()
155 .map_err(|e: cts_common::RegionError| AuthError::Server(e.to_string()))?;
156 Ok(Crn::new(region, workspace_id))
157 }
158
159 pub fn issuer(&self) -> Result<Url, AuthError> {
164 let claims = self.decode_claims()?;
165 claims.iss.parse().map_err(AuthError::from)
166 }
167
168 #[cfg(not(target_arch = "wasm32"))]
173 fn decode_claims(&self) -> Result<Claims, AuthError> {
174 use jsonwebtoken::{decode, decode_header, DecodingKey, Validation};
175 use std::collections::HashSet;
176
177 let token_str = self.access_token.as_str();
178 let header = decode_header(token_str)
179 .map_err(|e| AuthError::InvalidToken(format!("invalid JWT header: {e}")))?;
180
181 let dummy_key = DecodingKey::from_secret(&[]);
182 let mut validation = Validation::new(header.alg);
183 validation.validate_exp = false;
184 validation.validate_aud = false;
185 validation.required_spec_claims = HashSet::new();
186 validation.insecure_disable_signature_validation();
187
188 decode(token_str, &dummy_key, &validation)
189 .map(|data| data.claims)
190 .map_err(|e| AuthError::InvalidToken(format!("failed to decode JWT claims: {e}")))
191 }
192
193 #[cfg(target_arch = "wasm32")]
199 fn decode_claims(&self) -> Result<Claims, AuthError> {
200 crate::decode_jwt_payload_wasm(self.access_token.as_str())
201 }
202
203 pub async fn refresh(
218 refresh_token: &SecretToken,
219 base_url: &Url,
220 client_id: &str,
221 device_instance_id: Option<&str>,
222 ) -> Result<Token, AuthError> {
223 let token_url = base_url.join("oauth/token")?;
224
225 tracing::debug!(url = %token_url, "refreshing token");
226
227 let resp = http_client()
228 .post(token_url)
229 .form(&RefreshRequest {
230 grant_type: "refresh_token",
231 client_id,
232 refresh_token: refresh_token.as_str(),
233 device_instance_id,
234 })
235 .send()
236 .await?;
237
238 if !resp.status().is_success() {
239 let err: RefreshErrorResponse = resp.json().await?;
240 tracing::debug!(error = %err.error, "token refresh failed");
241 return Err(match err.error.as_str() {
242 "invalid_grant" => AuthError::InvalidGrant,
243 "invalid_client" => AuthError::InvalidClient,
244 "access_denied" => AuthError::AccessDenied,
245 _ => AuthError::Server(err.error_description),
246 });
247 }
248
249 let token_resp: RefreshResponse = resp.json().await?;
250 let now = SystemTime::now()
251 .duration_since(UNIX_EPOCH)
252 .unwrap_or_default()
253 .as_secs();
254
255 Ok(Token {
256 access_token: token_resp.access_token,
257 token_type: token_resp.token_type,
258 expires_at: now + token_resp.expires_in,
259 refresh_token: token_resp.refresh_token,
260 region: None,
261 client_id: None,
262 device_instance_id: None,
266 })
267 }
268}
269
/// Form body for the OAuth `refresh_token` grant, posted by `Token::refresh`.
///
/// Field names are the form keys, and fields are serialized in declaration order.
#[derive(serde::Serialize)]
struct RefreshRequest<'a> {
    /// Grant type; `Token::refresh` always sends `"refresh_token"`.
    grant_type: &'a str,
    /// OAuth client identifier.
    client_id: &'a str,
    /// The refresh token being exchanged.
    refresh_token: &'a str,
    /// Optional device instance identifier; left out of the form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    device_instance_id: Option<&'a str>,
}
278
/// Successful response body from the token endpoint, as read by `Token::refresh`.
#[derive(serde::Deserialize)]
struct RefreshResponse {
    /// The newly issued access token.
    access_token: SecretToken,
    /// Token type string (e.g. `"Bearer"`).
    token_type: String,
    /// Lifetime of `access_token` in seconds; `Token::refresh` turns it into an
    /// absolute `expires_at`.
    expires_in: u64,
    /// Replacement refresh token; `None` when the response omits it.
    #[serde(default)]
    refresh_token: Option<SecretToken>,
}
287
/// OAuth error response body from the token endpoint.
#[derive(serde::Deserialize)]
struct RefreshErrorResponse {
    /// OAuth error code (e.g. `invalid_grant`); `Token::refresh` maps it to an
    /// `AuthError` variant.
    error: String,
    /// Human-readable description; an empty string when the response omits it.
    #[serde(default)]
    error_description: String,
}
294
#[cfg(test)]
mod tests {
    use super::*;
    use crate::AuthError;
    use mocktail::prelude::*;

    /// Builds a (non-JWT) token expiring `expires_in` seconds from now,
    /// optionally carrying a refresh token.
    fn make_token(expires_in: u64, refresh: bool) -> Token {
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs();

        Token {
            access_token: SecretToken::new("test-access-token"),
            token_type: "Bearer".to_string(),
            expires_at: now + expires_in,
            refresh_token: if refresh {
                Some(SecretToken::new("test-refresh-token"))
            } else {
                None
            },
            region: None,
            client_id: None,
            device_instance_id: None,
        }
    }

    /// Successful token endpoint body, including a rotated refresh token.
    fn refresh_response_json() -> serde_json::Value {
        serde_json::json!({
            "access_token": "new-access-token",
            "token_type": "Bearer",
            "expires_in": 3600,
            "refresh_token": "new-refresh-token"
        })
    }

    /// OAuth error body with the given error code and a derived description.
    fn error_json(error: &str) -> serde_json::Value {
        serde_json::json!({
            "error": error,
            "error_description": format!("{error} occurred")
        })
    }

    /// Starts a mock HTTP server serving `mocks`.
    async fn start_server(mocks: MockSet) -> MockServer {
        let server = MockServer::new_http("token-refresh-test").with_mocks(mocks);
        server.start().await.unwrap();
        server
    }

    #[test]
    fn test_secret_token_debug_does_not_leak() {
        let token = SecretToken::new("super_secret_value");
        let debug = format!("{:?}", token);
        assert!(
            !debug.contains("super_secret_value"),
            "SecretToken Debug should not contain the secret, got: {debug}"
        );
    }

    #[test]
    fn test_is_expired_within_leeway_but_still_usable() {
        let token = make_token(EXPIRY_LEEWAY_SECS / 2, false);
        assert!(token.is_expired(), "token inside the leeway should report expired");
        assert!(token.is_usable(), "token before its expiry time should be usable");
    }

    #[test]
    fn test_not_expired_outside_leeway() {
        let token = make_token(3600, false);
        assert!(!token.is_expired());
        assert!(token.is_usable());
    }

    #[test]
    fn test_past_expiry_is_expired_and_unusable() {
        let token = Token {
            access_token: SecretToken::new("test-access-token"),
            token_type: "Bearer".to_string(),
            expires_at: 0,
            refresh_token: None,
            region: None,
            client_id: None,
            device_instance_id: None,
        };
        assert!(token.is_expired());
        assert!(!token.is_usable());
        assert_eq!(token.expires_in(), 0, "expires_in must saturate at zero");
    }

    #[test]
    fn test_take_refresh_token_leaves_none() {
        let mut token = make_token(3600, true);
        let taken = token
            .take_refresh_token()
            .expect("token was built with a refresh token");
        assert_eq!(taken.as_str(), "test-refresh-token");
        assert!(token.refresh_token().is_none());
        assert!(token.take_refresh_token().is_none());
    }

    #[test]
    fn test_optional_metadata_setters() {
        let mut token = make_token(3600, false);
        assert!(token.region().is_none());
        assert!(token.client_id().is_none());
        assert!(token.device_instance_id().is_none());

        token.set_region("ap-southeast-2.aws");
        token.set_client_id("cli");
        token.set_device_instance_id("device-1");

        assert_eq!(token.region(), Some("ap-southeast-2.aws"));
        assert_eq!(token.client_id(), Some("cli"));
        assert_eq!(token.device_instance_id(), Some("device-1"));
    }

    #[test]
    fn test_serde_omits_absent_optionals_and_round_trips() {
        let token = make_token(3600, false);
        let json = serde_json::to_value(&token).expect("token should serialize");
        let obj = json.as_object().expect("token should serialize to a JSON object");
        for key in ["refresh_token", "region", "client_id", "device_instance_id"] {
            assert!(!obj.contains_key(key), "unset `{key}` should be omitted: {obj:?}");
        }

        let restored: Token = serde_json::from_value(json).expect("token should deserialize");
        assert_eq!(restored.access_token().as_str(), "test-access-token");
        assert_eq!(restored.token_type(), "Bearer");
        assert_eq!(restored.expires_at(), token.expires_at());
        assert!(restored.refresh_token().is_none());
        assert!(restored.region().is_none());
    }

    #[tokio::test]
    async fn test_refresh_success() {
        let mut mocks = MockSet::new();
        mocks.mock(|when, then| {
            when.post().path("/oauth/token");
            then.json(refresh_response_json());
        });
        let server = start_server(mocks).await;
        let base_url = server.url("");

        let refresh_token = SecretToken::new("test-refresh-token");
        let refreshed = Token::refresh(&refresh_token, &base_url, "cli", None)
            .await
            .unwrap();

        assert_eq!(refreshed.access_token().as_str(), "new-access-token");
        assert_eq!(refreshed.token_type(), "Bearer");
        assert_eq!(
            refreshed.refresh_token().unwrap().as_str(),
            "new-refresh-token"
        );
        assert!(!refreshed.is_expired());
        assert!((3598..=3600).contains(&refreshed.expires_in()));
    }

    #[tokio::test]
    async fn test_refresh_invalid_grant() {
        let mut mocks = MockSet::new();
        mocks.mock(|when, then| {
            when.post().path("/oauth/token");
            then.bad_request().json(error_json("invalid_grant"));
        });
        let server = start_server(mocks).await;
        let base_url = server.url("");

        let refresh_token = SecretToken::new("test-refresh-token");
        let err = Token::refresh(&refresh_token, &base_url, "cli", None)
            .await
            .unwrap_err();

        assert!(matches!(err, AuthError::InvalidGrant));
    }

    #[tokio::test]
    async fn test_refresh_invalid_client() {
        let mut mocks = MockSet::new();
        mocks.mock(|when, then| {
            when.post().path("/oauth/token");
            then.bad_request().json(error_json("invalid_client"));
        });
        let server = start_server(mocks).await;
        let base_url = server.url("");

        let refresh_token = SecretToken::new("test-refresh-token");
        let err = Token::refresh(&refresh_token, &base_url, "cli", None)
            .await
            .unwrap_err();

        assert!(matches!(err, AuthError::InvalidClient));
    }

    #[tokio::test]
    async fn test_refresh_access_denied() {
        let mut mocks = MockSet::new();
        mocks.mock(|when, then| {
            when.post().path("/oauth/token");
            then.bad_request().json(error_json("access_denied"));
        });
        let server = start_server(mocks).await;
        let base_url = server.url("");

        let refresh_token = SecretToken::new("test-refresh-token");
        let err = Token::refresh(&refresh_token, &base_url, "cli", None)
            .await
            .unwrap_err();

        assert!(matches!(err, AuthError::AccessDenied));
    }

    #[tokio::test]
    async fn test_refresh_unknown_error() {
        let mut mocks = MockSet::new();
        mocks.mock(|when, then| {
            when.post().path("/oauth/token");
            then.bad_request().json(error_json("something_unexpected"));
        });
        let server = start_server(mocks).await;
        let base_url = server.url("");

        let refresh_token = SecretToken::new("test-refresh-token");
        let err = Token::refresh(&refresh_token, &base_url, "cli", None)
            .await
            .unwrap_err();

        assert!(matches!(&err, AuthError::Server(desc) if desc == "something_unexpected occurred"));
    }

    #[tokio::test]
    async fn test_refresh_response_without_new_refresh_token() {
        let mut mocks = MockSet::new();
        mocks.mock(|when, then| {
            when.post().path("/oauth/token");
            then.json(serde_json::json!({
                "access_token": "new-access-token",
                "token_type": "Bearer",
                "expires_in": 3600
            }));
        });
        let server = start_server(mocks).await;
        let base_url = server.url("");

        let refresh_token = SecretToken::new("test-refresh-token");
        let refreshed = Token::refresh(&refresh_token, &base_url, "cli", None)
            .await
            .unwrap();

        assert_eq!(refreshed.access_token().as_str(), "new-access-token");
        assert!(refreshed.refresh_token().is_none());
    }

    // Purely synchronous, so a plain #[test] rather than a tokio runtime.
    #[test]
    fn test_refresh_debug_does_not_leak_tokens() {
        let token = make_token(3600, true);
        let debug = format!("{:?}", token);
        assert!(
            !debug.contains("test-access-token"),
            "Debug output should not contain access token, got: {debug}"
        );
        assert!(
            !debug.contains("test-refresh-token"),
            "Debug output should not contain refresh token, got: {debug}"
        );
    }

    /// Builds a token whose access token is an HS256 JWT carrying `claims_json`.
    fn make_jwt_token(claims_json: serde_json::Value) -> Token {
        use jsonwebtoken::{encode, EncodingKey, Header};
        let jwt = encode(
            &Header::default(),
            &claims_json,
            &EncodingKey::from_secret(b"test-secret"),
        )
        .expect("failed to encode JWT");

        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs();

        Token {
            access_token: SecretToken::new(jwt),
            token_type: "Bearer".to_string(),
            expires_at: now + 3600,
            refresh_token: None,
            region: None,
            client_id: None,
            device_instance_id: None,
        }
    }

    /// A complete, valid claims set.
    fn valid_claims_json() -> serde_json::Value {
        serde_json::json!({
            "workspace": "7366ITCXSAPCH5TN",
            "iss": "https://cts.example.com",
            "sub": "user-123",
            "aud": "https://cts.example.com",
            "iat": 1700000000u64,
            "exp": 1700003600u64,
            "scope": "dataset:create"
        })
    }

    #[test]
    fn test_workspace_id_extracts_from_jwt() {
        let token = make_jwt_token(valid_claims_json());
        let ws = token.workspace_id().expect("should extract workspace ID");
        assert_eq!(ws.to_string(), "7366ITCXSAPCH5TN");
    }

    #[test]
    fn test_issuer_extracts_url_from_jwt() {
        let token = make_jwt_token(valid_claims_json());
        let issuer = token.issuer().expect("should extract issuer");
        assert_eq!(issuer.as_str(), "https://cts.example.com/");
    }

    #[test]
    fn test_workspace_id_fails_on_invalid_jwt() {
        let token = Token {
            access_token: SecretToken::new("not-a-jwt"),
            token_type: "Bearer".to_string(),
            expires_at: 0,
            refresh_token: None,
            region: None,
            client_id: None,
            device_instance_id: None,
        };
        let err = token.workspace_id().unwrap_err();
        assert!(matches!(err, AuthError::InvalidToken(_)));
    }

    #[test]
    fn test_issuer_fails_on_missing_claims() {
        let token = make_jwt_token(serde_json::json!({"sub": "user-123"}));
        let err = token.issuer().unwrap_err();
        assert!(matches!(err, AuthError::InvalidToken(_)));
    }

    #[test]
    fn test_workspace_crn_derives_from_region_and_workspace() {
        let mut token = make_jwt_token(valid_claims_json());
        token.set_region("ap-southeast-2.aws");
        let crn = token.workspace_crn().expect("should derive workspace CRN");
        assert_eq!(crn.to_string(), "crn:ap-southeast-2.aws:7366ITCXSAPCH5TN");
    }

    #[test]
    fn test_workspace_crn_fails_without_region() {
        let token = make_jwt_token(valid_claims_json());
        let err = token.workspace_crn().unwrap_err();
        assert!(matches!(err, AuthError::NotAuthenticated));
    }

    #[test]
    fn test_workspace_crn_fails_with_invalid_region() {
        let mut token = make_jwt_token(valid_claims_json());
        token.set_region("invalid-region");
        let err = token.workspace_crn().unwrap_err();
        assert!(matches!(err, AuthError::Server(_)));
    }
}