typesec_integrations/jwt/
authenticator.rs1use std::sync::{Arc, PoisonError, RwLock};
4use std::time::Instant;
5
6use jsonwebtoken::{
7 DecodingKey, TokenData, Validation, decode, decode_header,
8 jwk::{Jwk, JwkSet},
9};
10use typesec_core::typestate::{AgentError, Authenticator, Credentials};
11
12use crate::http::{HttpClient, ReqwestHttpClient};
13
14use super::claims::{JwtClaims, VerifiedSubject};
15use super::config::OidcConfig;
16
17pub struct JwtAuthenticator {
19 config: OidcConfig,
20 http: Arc<dyn HttpClient>,
21 jwks: RwLock<Option<CachedJwks>>,
22}
23
24#[derive(Clone)]
25struct CachedJwks {
26 keys: JwkSet,
27 fetched_at: Instant,
28}
29
30impl JwtAuthenticator {
31 pub fn new(config: OidcConfig) -> Self {
33 Self::with_http(config, Arc::new(ReqwestHttpClient::new()))
34 }
35
36 pub fn with_http(config: OidcConfig, http: Arc<dyn HttpClient>) -> Self {
38 Self {
39 config,
40 http,
41 jwks: RwLock::new(None),
42 }
43 }
44
45 pub fn verify(&self, token: &str) -> Result<VerifiedSubject, JwtAuthError> {
47 let data = self.decode_claims(token)?;
48 if !data.claims.aud.contains(&self.config.audience) {
49 return Err(JwtAuthError::InvalidAudience);
50 }
51 Ok(data.claims.into())
52 }
53
54 fn decode_claims(&self, token: &str) -> Result<TokenData<JwtClaims>, JwtAuthError> {
55 let header = decode_header(token)?;
56 let key = self.resolve_key(header.kid.as_deref())?;
57
58 let mut validation = Validation::default();
63 validation.algorithms = self.config.algorithms.clone();
64 validation.set_issuer(&[self.config.issuer.as_str()]);
65 validation.set_audience(&[self.config.audience.as_str()]);
66
67 Ok(decode::<JwtClaims>(
68 token,
69 &DecodingKey::from_jwk(&key)?,
70 &validation,
71 )?)
72 }
73
74 fn resolve_key(&self, kid: Option<&str>) -> Result<Jwk, JwtAuthError> {
82 let jwks = self.jwks(false)?;
83 match kid {
84 Some(kid) => {
85 if let Some(key) = jwks.find(kid) {
86 return Ok(key.clone());
87 }
88 let jwks = self.jwks(true)?;
90 jwks.find(kid).cloned().ok_or(JwtAuthError::MissingKey)
91 }
92 None => match jwks.keys.as_slice() {
93 [only] => Ok(only.clone()),
94 [] => Err(JwtAuthError::MissingKey),
95 _ => Err(JwtAuthError::MissingKid),
96 },
97 }
98 }
99
100 fn jwks(&self, force_refresh: bool) -> Result<JwkSet, JwtAuthError> {
101 if !force_refresh
105 && let Some(cached) = self
106 .jwks
107 .read()
108 .unwrap_or_else(PoisonError::into_inner)
109 .as_ref()
110 && cached.fetched_at.elapsed() < self.config.jwks_ttl
111 {
112 return Ok(cached.keys.clone());
113 }
114
115 let value = self.http.get_json(&self.config.jwks_url, &[])?;
116 let keys: JwkSet = serde_json::from_value(value)?;
117 *self.jwks.write().unwrap_or_else(PoisonError::into_inner) = Some(CachedJwks {
118 keys: keys.clone(),
119 fetched_at: Instant::now(),
120 });
121 Ok(keys)
122 }
123}
124
125impl Authenticator for JwtAuthenticator {
126 fn verify_credentials(&self, credentials: &Credentials) -> Result<String, AgentError> {
132 let verified =
133 self.verify(credentials.token.expose())
134 .map_err(|e| AgentError::AuthFailed {
135 reason: format!("jwt verification failed: {e}"),
136 })?;
137 if !credentials.subject.is_empty() && credentials.subject != verified.subject {
138 return Err(AgentError::AuthFailed {
139 reason: format!(
140 "claimed subject '{}' does not match verified token subject '{}'",
141 credentials.subject, verified.subject
142 ),
143 });
144 }
145 Ok(verified.subject)
146 }
147}
148
149#[derive(Debug, thiserror::Error)]
151pub enum JwtAuthError {
152 #[error("jwt validation failed: {0}")]
154 Jwt(#[from] jsonwebtoken::errors::Error),
155 #[error("jwks fetch failed: {0}")]
157 Http(#[from] Box<dyn std::error::Error + Send + Sync>),
158 #[error("jwks parse failed: {0}")]
160 Json(#[from] serde_json::Error),
161 #[error("no matching signing key found in JWKS")]
163 MissingKey,
164 #[error("token has no kid but JWKS is ambiguous (multiple keys)")]
166 MissingKid,
167 #[error("token audience did not match expected audience")]
169 InvalidAudience,
170}