1use anyhow;
2use async_cache;
3use async_trait::async_trait;
4use faststr::FastStr;
5use jsonwebtoken::decode_header;
6pub use jsonwebtoken::errors::Error as JwtError;
7pub use jsonwebtoken::errors::ErrorKind as JwtErrorKind;
8use jsonwebtoken::{decode, Validation};
9pub use jsonwebtoken::{DecodingKey, EncodingKey};
10use jwks::Jwks;
11use lazy_static::lazy_static;
12use serde::Serialize;
13use std::sync::Arc;
14use std::time::Duration;
15use thiserror;
16
17use super::identity::{IncomingClaims, SpacetimeIdentityClaims};
18use super::JwtKeys;
19
/// Reasons a token can be rejected by a [`TokenValidator`].
#[derive(thiserror::Error, Debug)]
pub enum TokenValidationError {
    /// Raised by `jsonwebtoken` while decoding the header or the claims
    /// (e.g. bad signature, disallowed algorithm, missing required claim).
    #[error("Invalid token: {0}")]
    TokenError(#[from] JwtError),

    /// The token header names a `kid` that is not present in the issuer's key set.
    #[error("Specified key ID not found in JWKs")]
    KeyIDNotFound,

    /// Errors forwarded from the `jwks` crate, e.g. while fetching an issuer's
    /// key set with `Jwks::from_oidc_url`.
    #[error(transparent)]
    JwkError(#[from] jwks::JwkError),
    #[error(transparent)]
    JwksError(#[from] jwks::JwksError),
    /// Any other failure: issuer mismatch, claims conversion failure, key-fetch
    /// failure reported by the cache, or an unimplemented validator.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
39
40pub trait TokenSigner: Sync + Send {
42 fn sign<T: Serialize>(&self, claims: &T) -> Result<String, JwtError>;
44}
45
46impl TokenSigner for EncodingKey {
47 fn sign<Token: Serialize>(&self, claims: &Token) -> Result<String, JwtError> {
48 let header = jsonwebtoken::Header::new(jsonwebtoken::Algorithm::ES256);
49 jsonwebtoken::encode(&header, claims, self)
50 }
51}
52
53impl TokenSigner for JwtKeys {
54 fn sign<Token: Serialize>(&self, claims: &Token) -> Result<String, JwtError> {
55 self.private.sign(claims)
56 }
57}
58
/// Validates a raw JWT and, on success, returns the identity claims it carries.
#[async_trait]
pub trait TokenValidator {
    /// Validates `token`, returning its claims or the reason it was rejected.
    async fn validate_token(&self, token: &str) -> Result<SpacetimeIdentityClaims, TokenValidationError>;
}
67
68#[async_trait]
69impl<T: TokenValidator + Send + Sync> TokenValidator for Arc<T> {
70 async fn validate_token(&self, token: &str) -> Result<SpacetimeIdentityClaims, TokenValidationError> {
71 (**self).validate_token(token).await
72 }
73}
74
75pub struct UnimplementedTokenValidator;
76
77#[async_trait]
78impl TokenValidator for UnimplementedTokenValidator {
79 async fn validate_token(&self, _token: &str) -> Result<SpacetimeIdentityClaims, TokenValidationError> {
80 Err(TokenValidationError::Other(anyhow::anyhow!("Unimplemented")))
81 }
82}
83
84pub struct FullTokenValidator<T: TokenValidator + Send + Sync> {
88 pub local_key: DecodingKey,
89 pub local_issuer: String,
90 pub oidc_validator: T,
91}
92
93#[async_trait]
94impl<T> TokenValidator for FullTokenValidator<T>
95where
96 T: TokenValidator + Send + Sync,
97{
98 async fn validate_token(&self, token: &str) -> Result<SpacetimeIdentityClaims, TokenValidationError> {
99 let local_key_error = {
100 let first_validator = BasicTokenValidator {
101 public_key: self.local_key.clone(),
102 issuer: None,
103 };
104 match first_validator.validate_token(token).await {
105 Ok(claims) => return Ok(claims),
106 Err(e) => e,
107 }
108 };
109
110 let issuer = get_raw_issuer(token)?;
112 if issuer == self.local_issuer {
115 return Err(local_key_error);
116 }
117 self.oidc_validator.validate_token(token).await
118 }
119}
120
121pub type DefaultValidator = FullTokenValidator<CachingOidcTokenValidator>;
122
123pub fn new_validator(local_key: DecodingKey, local_issuer: String) -> FullTokenValidator<CachingOidcTokenValidator> {
124 FullTokenValidator {
125 local_key,
126 local_issuer,
127 oidc_validator: CachingOidcTokenValidator::get_default(),
128 }
129}
130
/// Validates tokens against a single public key, optionally requiring a specific issuer.
struct BasicTokenValidator {
    /// Key the token's signature must verify against.
    pub public_key: DecodingKey,
    /// When `Some`, the token's issuer claim must equal this value; `None` skips the check.
    pub issuer: Option<String>,
}
138
139lazy_static! {
140 static ref REQUIRED_CLAIMS: Vec<&'static str> = vec!["sub", "iss"];
142}
143
144#[async_trait]
145impl TokenValidator for DecodingKey {
146 async fn validate_token(&self, token: &str) -> Result<SpacetimeIdentityClaims, TokenValidationError> {
147 let mut validation = Validation::new(jsonwebtoken::Algorithm::ES256);
148 validation.algorithms = vec![
149 jsonwebtoken::Algorithm::ES256,
150 jsonwebtoken::Algorithm::RS256,
151 jsonwebtoken::Algorithm::HS256,
152 ];
153 validation.set_required_spec_claims(&REQUIRED_CLAIMS);
154
155 validation.validate_aud = false;
157
158 let data = decode::<IncomingClaims>(token, self, &validation)?;
159 let claims = data.claims;
160 claims.try_into().map_err(TokenValidationError::Other)
161 }
162}
163
164#[async_trait]
165impl TokenValidator for BasicTokenValidator {
166 async fn validate_token(&self, token: &str) -> Result<SpacetimeIdentityClaims, TokenValidationError> {
167 let claims = self.public_key.validate_token(token).await?;
169 if let Some(expected_issuer) = &self.issuer {
170 if claims.issuer != *expected_issuer {
171 return Err(TokenValidationError::Other(anyhow::anyhow!(
172 "Issuer mismatch: got {:?}, expected {:?}",
173 claims.issuer,
174 expected_issuer
175 )));
176 }
177 }
178 Ok(claims)
179 }
180}
181
/// OIDC validator that caches one `JwksValidator` per issuer string.
pub struct CachingOidcTokenValidator {
    // Keyed by the raw issuer taken from each token; entries are produced by `KeyFetcher`.
    cache: async_cache::AsyncCache<Arc<JwksValidator>, KeyFetcher>,
}
186
187impl CachingOidcTokenValidator {
188 pub fn new(refresh_duration: Duration, expiry: Option<Duration>) -> Self {
189 let cache = async_cache::Options::new(refresh_duration, KeyFetcher)
190 .with_expire(expiry)
191 .build();
192 CachingOidcTokenValidator { cache }
193 }
194
195 pub fn get_default() -> Self {
196 Self::new(Duration::from_secs(300), Some(Duration::from_secs(7200)))
197 }
198}
199
200struct KeyFetcher;
202
203impl async_cache::Fetcher<Arc<JwksValidator>> for KeyFetcher {
204 type Error = TokenValidationError;
205
206 async fn fetch(&self, key: FastStr) -> Result<Arc<JwksValidator>, Self::Error> {
207 let raw_issuer = key.to_string();
209 log::info!("Fetching key for issuer {}", raw_issuer.clone());
210 let oidc_url = format!("{}/.well-known/openid-configuration", raw_issuer.trim_end_matches('/'));
211 let key_or_error = Jwks::from_oidc_url(oidc_url).await;
212 if let Err(e) = &key_or_error {
215 log::warn!("Error fetching public key for issuer {raw_issuer}: {e:?}");
216 }
217 let keys = key_or_error?;
218 let validator = JwksValidator {
219 issuer: raw_issuer.clone(),
220 keyset: keys,
221 };
222 Ok(Arc::new(validator))
223 }
224}
225
226#[async_trait]
227impl TokenValidator for CachingOidcTokenValidator {
228 async fn validate_token(&self, token: &str) -> Result<SpacetimeIdentityClaims, TokenValidationError> {
229 let raw_issuer = get_raw_issuer(token)?;
230 log::debug!("Getting validator for issuer {}", raw_issuer.clone());
231 let validator = self
232 .cache
233 .get(raw_issuer.clone().into())
234 .await
235 .ok_or_else(|| anyhow::anyhow!("Error fetching public key for issuer {}", raw_issuer))?;
236 validator.validate_token(token).await
237 }
238}
239
/// OIDC validator that performs discovery and fetches the issuer's key set on every call (no caching).
pub struct OidcTokenValidator;
244
245fn get_raw_issuer(token: &str) -> Result<String, TokenValidationError> {
247 let mut validation = Validation::new(jsonwebtoken::Algorithm::ES256);
248 validation.set_required_spec_claims(&REQUIRED_CLAIMS);
249 validation.validate_aud = false;
250 validation.insecure_disable_signature_validation();
252 let data = decode::<IncomingClaims>(token, &DecodingKey::from_secret(b"fake"), &validation)?;
253 Ok(data.claims.issuer)
254}
255
256#[async_trait]
257impl TokenValidator for OidcTokenValidator {
258 async fn validate_token(&self, token: &str) -> Result<SpacetimeIdentityClaims, TokenValidationError> {
259 let raw_issuer = get_raw_issuer(token)?;
261 let oidc_url = format!("{}/.well-known/openid-configuration", raw_issuer.trim_end_matches('/'));
262 log::debug!("Fetching key for issuer {}", raw_issuer.clone());
263 let key_or_error = Jwks::from_oidc_url(oidc_url).await;
264 if let Err(e) = &key_or_error {
267 log::warn!("Error fetching public key for issuer {raw_issuer}: {e:?}");
268 }
269 let keys = key_or_error?;
270 let validator = JwksValidator {
271 issuer: raw_issuer,
272 keyset: keys,
273 };
274 validator.validate_token(token).await
275 }
276}
277
/// Validates tokens against one issuer's JSON Web Key Set.
struct JwksValidator {
    /// Issuer the token's issuer claim must equal.
    pub issuer: String,
    /// The issuer's published keys, looked up by key ID (`kid`).
    pub keyset: Jwks,
}
282
283#[async_trait]
284impl TokenValidator for JwksValidator {
285 async fn validate_token(&self, token: &str) -> Result<SpacetimeIdentityClaims, TokenValidationError> {
286 let header = decode_header(token)?;
287 if let Some(kid) = header.kid {
288 let key = self
289 .keyset
290 .keys
291 .get(&kid)
292 .ok_or_else(|| TokenValidationError::KeyIDNotFound)?;
293 let validator = BasicTokenValidator {
294 public_key: key.decoding_key.clone(),
295 issuer: Some(self.issuer.clone()),
296 };
297 return validator.validate_token(token).await;
298 }
299 log::debug!("No key id in header. Trying all keys.");
300 let mut last_error = TokenValidationError::Other(anyhow::anyhow!("No kid found"));
303 for (kid, key) in &self.keyset.keys {
304 log::debug!("Trying key {kid}");
305 let validator = BasicTokenValidator {
306 public_key: key.decoding_key.clone(),
307 issuer: Some(self.issuer.clone()),
308 };
309 match validator.validate_token(token).await {
310 Ok(claims) => return Ok(claims),
311 Err(e) => {
312 last_error = e;
313 log::debug!("Validating with key {kid} failed");
314 continue;
315 }
316 }
317 }
318 Err(last_error)
320 }
321}
322
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use crate::auth::identity::{IncomingClaims, SpacetimeIdentityClaims};
    use crate::auth::token_validation::{
        BasicTokenValidator, CachingOidcTokenValidator, FullTokenValidator, OidcTokenValidator, TokenSigner,
        TokenValidator,
    };
    use crate::auth::JwtKeys;
    use axum::routing::get;
    use axum::Json;
    use axum::Router;
    use base64::Engine;
    use openssl::ec::{EcGroup, EcKey};
    use serde::{Deserialize, Serialize};
    use spacetimedb_lib::Identity;
    use tokio::net::TcpListener;
    use tokio::sync::oneshot;

    /// Byte length of one P-256 affine coordinate in a JWK.
    const P256_COORD_LEN: usize = 32;

    /// Builds test claims for `issuer`/`subject` with no audience and no expiry.
    fn test_claims(issuer: &str, subject: &str) -> IncomingClaims {
        IncomingClaims {
            identity: None,
            subject: subject.to_string(),
            issuer: issuer.to_string(),
            audience: vec![],
            iat: std::time::SystemTime::now(),
            exp: None,
        }
    }

    /// Asserts that `claims` carry `issuer`/`subject` and the identity derived from them.
    fn assert_claims_match(claims: &SpacetimeIdentityClaims, issuer: &str, subject: &str) {
        assert_eq!(claims.issuer, issuer);
        assert_eq!(claims.subject, subject);
        assert_eq!(claims.identity, Identity::from_claims(issuer, subject));
    }

    /// Fails if `validator` accepts `token`, reporting the claims it produced.
    async fn assert_validation_fails<T: TokenValidator>(validator: &T, token: &str) -> anyhow::Result<()> {
        if let Ok(claims) = validator.validate_token(token).await {
            anyhow::bail!("Validation succeeded when it should have failed: {:?}", claims);
        }
        Ok(())
    }

    #[tokio::test]
    async fn test_local_validator_checks_issuer() -> anyhow::Result<()> {
        let kp = JwtKeys::generate()?;
        let issuer = "test1";
        let subject = "test_subject";
        let token = kp.private.sign(&test_claims(issuer, subject))?;

        // Right key, right issuer: accepted.
        {
            let validator = BasicTokenValidator {
                public_key: kp.public.clone(),
                issuer: Some(issuer.to_string()),
            };
            let parsed_claims: SpacetimeIdentityClaims = validator.validate_token(&token).await?;
            assert_claims_match(&parsed_claims, issuer, subject);
        }
        // Right key, wrong issuer: rejected.
        {
            let validator = BasicTokenValidator {
                public_key: kp.public.clone(),
                issuer: Some("otherissuer".to_string()),
            };
            assert_validation_fails(&validator, &token).await?;
        }

        Ok(())
    }

    #[tokio::test]
    async fn test_local_validator_checks_key() -> anyhow::Result<()> {
        let kp = JwtKeys::generate()?;
        let issuer = "test1";
        let subject = "test_subject";
        let token = kp.private.sign(&test_claims(issuer, subject))?;

        // Right key, right issuer: accepted.
        {
            let validator = BasicTokenValidator {
                public_key: kp.public.clone(),
                issuer: Some(issuer.to_string()),
            };
            let parsed_claims: SpacetimeIdentityClaims = validator.validate_token(&token).await?;
            assert_claims_match(&parsed_claims, issuer, subject);
        }
        // Right issuer, wrong key: must be rejected on the signature alone. (Using a
        // different issuer here would let the test pass even if keys were not checked.)
        {
            let other_kp = JwtKeys::generate()?;
            let validator = BasicTokenValidator {
                public_key: other_kp.public,
                issuer: Some(issuer.to_string()),
            };
            assert_validation_fails(&validator, &token).await?;
        }

        Ok(())
    }

    #[tokio::test]
    async fn resigned_token_ignores_issuer() -> anyhow::Result<()> {
        let kp = JwtKeys::generate()?;
        let local_issuer = "test1";
        let external_issuer = "other_issuer";
        let subject = "test_subject";
        let token = kp.private.sign(&test_claims(external_issuer, subject))?;

        // Signed by the local key, so the full validator accepts it despite the foreign issuer.
        {
            let validator = FullTokenValidator {
                local_key: kp.public.clone(),
                local_issuer: local_issuer.to_string(),
                oidc_validator: OidcTokenValidator,
            };
            let parsed_claims: SpacetimeIdentityClaims = validator.validate_token(&token).await?;
            assert_claims_match(&parsed_claims, external_issuer, subject);
        }
        assert_validation_fails(&OidcTokenValidator, &token).await?;
        assert_validation_fails(
            &BasicTokenValidator {
                public_key: kp.public.clone(),
                issuer: Some(local_issuer.to_string()),
            },
            &token,
        )
        .await?;
        Ok(())
    }

    /// Minimal OIDC discovery document: only the JWKS location is served.
    #[derive(Deserialize, Serialize, Clone)]
    struct OIDCConfig {
        jwks_uri: String,
    }

    async fn oidc_config_handler(config: OIDCConfig) -> Json<OIDCConfig> {
        Json(config)
    }

    /// A local HTTP server serving an OIDC discovery document and a JWKS.
    ///
    /// Dropping the handle drops `shutdown_tx`, which resolves the server's
    /// graceful-shutdown future.
    struct OIDCServerHandle {
        pub base_url: String,
        // Never read: held only so the server keeps running until the handle is dropped.
        #[allow(dead_code)]
        pub shutdown_tx: oneshot::Sender<()>,
        // Never read: keeps the spawned server task owned by the handle.
        #[allow(dead_code)]
        join_handle: tokio::task::JoinHandle<()>,
    }

    impl OIDCServerHandle {
        /// Starts a server on an ephemeral loopback port that serves `jwks_json`, and
        /// waits until it answers a health check.
        pub async fn start_new(jwks_json: String) -> anyhow::Result<Self> {
            // Loopback only: the server is addressed through `localhost` below.
            let listener = TcpListener::bind("127.0.0.1:0").await?;
            let port = listener.local_addr()?.port();
            let base_url = format!("http://localhost:{port}");
            let config = OIDCConfig {
                jwks_uri: format!("{base_url}/jwks.json"),
            };
            let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();

            let app = Router::new()
                .route(
                    "/.well-known/openid-configuration",
                    get(move || oidc_config_handler(config.clone())),
                )
                .route("/jwks.json", get(move || async move { jwks_json }))
                .route("/ok", get(|| async move { "OK" }));

            let join_handle = tokio::spawn(async move {
                axum::serve(listener, app)
                    .with_graceful_shutdown(async {
                        // Resolves on an explicit send or when the sender is dropped.
                        shutdown_rx.await.ok();
                    })
                    .await
                    .expect("test OIDC server failed");
            });

            wait_until_ready(&format!("{base_url}/ok")).await?;

            Ok(OIDCServerHandle {
                base_url,
                shutdown_tx,
                join_handle,
            })
        }
    }

    /// Polls `url` until it returns a success status, giving up after a fixed number of attempts.
    async fn wait_until_ready(url: &str) -> anyhow::Result<()> {
        const MAX_ATTEMPTS: u32 = 10;
        const DELAY_MS: u64 = 50;

        let client = reqwest::Client::new();
        for _ in 0..MAX_ATTEMPTS {
            match client.get(url).send().await {
                Ok(response) if response.status().is_success() => return Ok(()),
                _ => {
                    log::debug!("Server not ready. Waiting...");
                    tokio::time::sleep(Duration::from_millis(DELAY_MS)).await;
                }
            }
        }
        Err(anyhow::anyhow!("Server failed to start after maximum attempts"))
    }

    #[derive(Debug, Default, Copy, Clone)]
    struct TestOptions {
        /// Whether the issuer in the signed claims ends with a `/`.
        pub issuer_trailing_slash: bool,
    }

    /// Serves a two-key JWKS from a local OIDC server. Checks that `validator` accepts
    /// tokens signed by either key and rejects one signed by an unrelated key.
    async fn run_oidc_test<T: TokenValidator>(validator: T, opts: &TestOptions) -> anyhow::Result<()> {
        let mut kp1 = JwtKeys::generate()?;
        let mut kp2 = JwtKeys::generate()?;
        kp1.kid = Some("key1".to_string());
        kp2.kid = Some("key2".to_string());
        let invalid_kp = JwtKeys::generate()?;

        let jwks = keyset_to_json([kp1.clone(), kp2.clone()])?;
        // `handle` must stay alive for the whole test: dropping it stops the server.
        let handle = OIDCServerHandle::start_new(jwks).await?;

        let issuer = if opts.issuer_trailing_slash {
            format!("{}/", handle.base_url)
        } else {
            handle.base_url.clone()
        };
        let subject = "test_subject";
        let orig_claims = test_claims(&issuer, subject);

        for kp in [kp1, kp2] {
            log::debug!("Testing with key {:?}", kp.kid);
            let token = kp.private.sign(&orig_claims)?;
            let validated_claims = validator.validate_token(&token).await?;
            assert_claims_match(&validated_claims, &issuer, subject);
        }

        let invalid_token = invalid_kp.private.sign(&orig_claims)?;
        assert_validation_fails(&validator, &invalid_token).await?;

        Ok(())
    }

    #[tokio::test]
    async fn test_oidc_flow() -> anyhow::Result<()> {
        for _ in 0..10 {
            run_oidc_test(OidcTokenValidator, &Default::default()).await?
        }
        Ok(())
    }

    #[tokio::test]
    async fn test_issuer_slash() -> anyhow::Result<()> {
        let opts = TestOptions {
            issuer_trailing_slash: true,
        };

        run_oidc_test(OidcTokenValidator, &opts).await?;
        run_oidc_test(CachingOidcTokenValidator::get_default(), &opts).await?;
        Ok(())
    }

    #[tokio::test]
    async fn test_caching_oidc_flow() -> anyhow::Result<()> {
        for _ in 0..10 {
            let v = CachingOidcTokenValidator::get_default();
            run_oidc_test(v, &Default::default()).await?;
        }
        Ok(())
    }

    #[tokio::test]
    async fn test_full_validator_fallback() -> anyhow::Result<()> {
        let kp = JwtKeys::generate()?;
        let v = FullTokenValidator {
            local_key: kp.public,
            local_issuer: "local_issuer".to_string(),
            oidc_validator: OidcTokenValidator,
        };
        run_oidc_test(v, &Default::default()).await
    }

    /// Serializes the public halves of `jks` as a JWKS document: `{"keys": [...]}`.
    fn keyset_to_json<I>(jks: I) -> anyhow::Result<String>
    where
        I: IntoIterator<Item = JwtKeys>,
    {
        let keys = jks
            .into_iter()
            .map(|key| to_jwk_json(&key))
            .collect::<anyhow::Result<Vec<serde_json::Value>>>()?;
        Ok(serde_json::json!({ "keys": keys }).to_string())
    }

    /// Left-pads `bytes` with zeros to `len` bytes; inputs already that long are returned unchanged.
    fn left_pad(bytes: Vec<u8>, len: usize) -> Vec<u8> {
        if bytes.len() >= len {
            return bytes;
        }
        let mut padded = vec![0u8; len - bytes.len()];
        padded.extend_from_slice(&bytes);
        padded
    }

    /// Converts the P-256 public key of `jk` into an ES256 JWK, including `kid` when set.
    fn to_jwk_json(jk: &JwtKeys) -> anyhow::Result<serde_json::Value> {
        let eck = EcKey::public_key_from_pem(&jk.public_pem)?;
        let group = EcGroup::from_curve_name(openssl::nid::Nid::X9_62_PRIME256V1)?;
        let mut ctx = openssl::bn::BigNumContext::new()?;
        let mut x = openssl::bn::BigNum::new()?;
        let mut y = openssl::bn::BigNum::new()?;
        eck.public_key().affine_coordinates(&group, &mut x, &mut y, &mut ctx)?;

        // `to_vec` drops leading zero bytes, but JWK coordinates must be full-width.
        let x_b64 = base64::prelude::BASE64_URL_SAFE_NO_PAD.encode(left_pad(x.to_vec(), P256_COORD_LEN));
        let y_b64 = base64::prelude::BASE64_URL_SAFE_NO_PAD.encode(left_pad(y.to_vec(), P256_COORD_LEN));

        let mut jwk = serde_json::json!({
            "kty": "EC",
            "crv": "P-256",
            "use": "sig",
            "alg": "ES256",
            "x": x_b64,
            "y": y_b64
        });
        if let Some(kid) = &jk.kid {
            jwk["kid"] = kid.clone().into();
        }
        Ok(jwk)
    }
}