1use std::collections::HashSet;
4use std::sync::{Arc, RwLock};
5use std::time::{Duration, Instant};
6
7use jsonwebtoken::{
8 Algorithm, DecodingKey, TokenData, Validation, decode, decode_header,
9 jwk::{Jwk, JwkSet},
10};
11use serde::{Deserialize, Serialize};
12use serde_json::Value;
13use tracing::debug;
14use typesec_core::policy::{PolicyEngine, PolicyResult};
15use typesec_core::typestate::{AgentError, Authenticator, Credentials};
16
17use crate::http::{HttpClient, ReqwestHttpClient};
18
/// Settings for validating JWTs issued by an OpenID Connect provider.
#[derive(Debug, Clone)]
pub struct OidcConfig {
    /// Expected `iss` claim; it is passed to the validator as the only accepted issuer.
    pub issuer: String,
    /// Expected audience; the token's `aud` claim must contain this value.
    pub audience: String,
    /// URL of the issuer's JSON Web Key Set, fetched via the configured HTTP client.
    pub jwks_url: String,
    /// Signature algorithms a token header is allowed to declare.
    pub algorithms: Vec<Algorithm>,
    /// How long a fetched JWKS is reused before it is fetched again.
    pub jwks_ttl: Duration,
}
36
37impl OidcConfig {
38 pub fn new(
40 issuer: impl Into<String>,
41 audience: impl Into<String>,
42 jwks_url: impl Into<String>,
43 ) -> Self {
44 Self {
45 issuer: issuer.into(),
46 audience: audience.into(),
47 jwks_url: jwks_url.into(),
48 algorithms: vec![Algorithm::RS256],
49 jwks_ttl: Duration::from_secs(300),
50 }
51 }
52}
53
/// Claims carried in a JWT, as decoded by this module.
///
/// `sub`, `iss`, `aud` and `exp` are required for deserialization; the remaining
/// fields are `#[serde(default)]` and become `None` / empty when absent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JwtClaims {
    /// Subject (`sub`): identifier of the authenticated principal.
    pub sub: String,
    /// Issuer (`iss`).
    pub iss: String,
    /// Audience (`aud`): either a single string or an array of strings.
    pub aud: Audience,
    /// Expiry (`exp`) as a Unix timestamp in seconds.
    pub exp: usize,
    #[serde(default)]
    /// Organization id, if the issuer included one.
    pub org_id: Option<String>,
    #[serde(default)]
    /// Organization membership id, if present (see `VerifiedSubject::workos_membership_subject`).
    pub organization_membership_id: Option<String>,
    #[serde(default)]
    /// A single role name, if present.
    pub role: Option<String>,
    #[serde(default)]
    /// Permission strings granted to the subject.
    pub permissions: Vec<String>,
}
78
/// The JWT `aud` claim, which may be encoded as one string or an array of strings.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Audience {
    /// `"aud": "value"`
    Single(String),
    /// `"aud": ["value-a", "value-b"]`
    Multiple(Vec<String>),
}
88
89impl Audience {
90 fn contains(&self, needle: &str) -> bool {
91 match self {
92 Self::Single(value) => value == needle,
93 Self::Multiple(values) => values.iter().any(|value| value == needle),
94 }
95 }
96}
97
/// Identity and authorization data extracted from a successfully verified JWT.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct VerifiedSubject {
    /// The token's `sub` claim.
    pub subject: String,
    /// Organization id, when the token carried one.
    pub org_id: Option<String>,
    /// Organization membership id, when the token carried one.
    pub organization_membership_id: Option<String>,
    /// Role names; converting from `JwtClaims` yields at most one (its `role` claim).
    pub roles: Vec<String>,
    /// Permission strings granted by the token.
    pub permissions: Vec<String>,
}

impl VerifiedSubject {
    /// Identifier to use as the subject for WorkOS membership-based checks.
    ///
    /// Returns the organization membership id when present, otherwise falls
    /// back to the plain `subject`.
    pub fn workos_membership_subject(&self) -> &str {
        match &self.organization_membership_id {
            Some(membership_id) => membership_id.as_str(),
            None => self.subject.as_str(),
        }
    }
}
121
122impl From<JwtClaims> for VerifiedSubject {
123 fn from(claims: JwtClaims) -> Self {
124 Self {
125 subject: claims.sub,
126 org_id: claims.org_id,
127 organization_membership_id: claims.organization_membership_id,
128 roles: claims.role.into_iter().collect(),
129 permissions: claims.permissions,
130 }
131 }
132}
133
134pub struct JwtAuthenticator {
136 config: OidcConfig,
137 http: Arc<dyn HttpClient>,
138 jwks: RwLock<Option<CachedJwks>>,
139}
140
141#[derive(Clone)]
142struct CachedJwks {
143 keys: JwkSet,
144 fetched_at: Instant,
145}
146
147impl JwtAuthenticator {
148 pub fn new(config: OidcConfig) -> Self {
150 Self::with_http(config, Arc::new(ReqwestHttpClient::new()))
151 }
152
153 pub fn with_http(config: OidcConfig, http: Arc<dyn HttpClient>) -> Self {
155 Self {
156 config,
157 http,
158 jwks: RwLock::new(None),
159 }
160 }
161
162 pub fn verify(&self, token: &str) -> Result<VerifiedSubject, JwtAuthError> {
164 let data = self.decode_claims(token)?;
165 if !data.claims.aud.contains(&self.config.audience) {
166 return Err(JwtAuthError::InvalidAudience);
167 }
168 Ok(data.claims.into())
169 }
170
171 fn decode_claims(&self, token: &str) -> Result<TokenData<JwtClaims>, JwtAuthError> {
172 let header = decode_header(token)?;
173 let key = self.resolve_key(header.kid.as_deref())?;
174
175 let mut validation = Validation::new(header.alg);
176 validation.algorithms = self.config.algorithms.clone();
177 validation.set_issuer(&[self.config.issuer.as_str()]);
178 validation.set_audience(&[self.config.audience.as_str()]);
179
180 Ok(decode::<JwtClaims>(
181 token,
182 &DecodingKey::from_jwk(&key)?,
183 &validation,
184 )?)
185 }
186
187 fn resolve_key(&self, kid: Option<&str>) -> Result<Jwk, JwtAuthError> {
195 let jwks = self.jwks(false)?;
196 match kid {
197 Some(kid) => {
198 if let Some(key) = jwks.find(kid) {
199 return Ok(key.clone());
200 }
201 let jwks = self.jwks(true)?;
203 jwks.find(kid).cloned().ok_or(JwtAuthError::MissingKey)
204 }
205 None => match jwks.keys.as_slice() {
206 [only] => Ok(only.clone()),
207 [] => Err(JwtAuthError::MissingKey),
208 _ => Err(JwtAuthError::MissingKid),
209 },
210 }
211 }
212
213 fn jwks(&self, force_refresh: bool) -> Result<JwkSet, JwtAuthError> {
214 if !force_refresh
215 && let Some(cached) = self.jwks.read().expect("jwks lock poisoned").as_ref()
216 && cached.fetched_at.elapsed() < self.config.jwks_ttl
217 {
218 return Ok(cached.keys.clone());
219 }
220
221 let value = self.http.get_json(&self.config.jwks_url, &[])?;
222 let keys: JwkSet = serde_json::from_value(value)?;
223 *self.jwks.write().expect("jwks lock poisoned") = Some(CachedJwks {
224 keys: keys.clone(),
225 fetched_at: Instant::now(),
226 });
227 Ok(keys)
228 }
229}
230
231impl Authenticator for JwtAuthenticator {
232 fn verify_credentials(&self, credentials: &Credentials) -> Result<String, AgentError> {
238 let verified =
239 self.verify(credentials.token.expose())
240 .map_err(|e| AgentError::AuthFailed {
241 reason: format!("jwt verification failed: {e}"),
242 })?;
243 if !credentials.subject.is_empty() && credentials.subject != verified.subject {
244 return Err(AgentError::AuthFailed {
245 reason: format!(
246 "claimed subject '{}' does not match verified token subject '{}'",
247 credentials.subject, verified.subject
248 ),
249 });
250 }
251 Ok(verified.subject)
252 }
253}
254
/// Errors produced while verifying a JWT.
#[derive(Debug, thiserror::Error)]
pub enum JwtAuthError {
    /// An error from `jsonwebtoken`: header decoding, key construction from the
    /// JWK, or signature/claim validation.
    #[error("jwt validation failed: {0}")]
    Jwt(#[from] jsonwebtoken::errors::Error),
    /// The HTTP client failed to fetch the JWKS.
    #[error("jwks fetch failed: {0}")]
    Http(#[from] Box<dyn std::error::Error + Send + Sync>),
    /// The JWKS response could not be deserialized into a JWK set.
    #[error("jwks parse failed: {0}")]
    Json(#[from] serde_json::Error),
    /// No JWKS key matched the token's `kid`, or the JWKS was empty.
    #[error("no matching signing key found in JWKS")]
    MissingKey,
    /// The token had no `kid` and the JWKS holds more than one key.
    #[error("token has no kid but JWKS is ambiguous (multiple keys)")]
    MissingKid,
    /// The decoded `aud` claim did not contain the configured audience.
    #[error("token audience did not match expected audience")]
    InvalidAudience,
}
277
/// A `PolicyEngine` that answers from the permissions carried in a verified JWT.
///
/// It only ever allows or delegates; it never denies.
pub struct JwtClaimsEngine {
    /// Subject the claims belong to; checks for any other subject are delegated.
    subject: String,
    /// Permission strings: either bare actions (`"read"`) or
    /// `"<resource_type>:<action>"` (`"project:edit"`).
    permissions: HashSet<String>,
    /// Organization id from the token; in this file it is only used for debug logging.
    org_id: Option<String>,
}
288
289impl JwtClaimsEngine {
290 pub fn new(subject: VerifiedSubject) -> Self {
292 Self {
293 subject: subject.subject,
294 permissions: subject.permissions.into_iter().collect(),
295 org_id: subject.org_id,
296 }
297 }
298
299 pub fn from_permissions(
301 subject: impl Into<String>,
302 permissions: impl IntoIterator<Item = String>,
303 ) -> Self {
304 Self {
305 subject: subject.into(),
306 permissions: permissions.into_iter().collect(),
307 org_id: None,
308 }
309 }
310
311 fn permission_matches(&self, action: &str, resource: &str) -> bool {
312 if self.permissions.contains(action) {
313 return true;
314 }
315
316 let resource_type = resource.split(['/', ':']).next().unwrap_or(resource);
317 self.permissions
318 .contains(&format!("{resource_type}:{action}"))
319 }
320}
321
322impl PolicyEngine for JwtClaimsEngine {
323 fn check(&self, subject: &str, action: &str, resource: &str) -> PolicyResult {
324 debug!(subject, action, resource, org_id = ?self.org_id, "jwt claims check");
325
326 if subject != self.subject {
327 return PolicyResult::Delegate(format!(
328 "jwt claims are for '{}', not '{subject}'",
329 self.subject
330 ));
331 }
332
333 if self.permission_matches(action, resource) {
334 PolicyResult::Allow
335 } else {
336 PolicyResult::Delegate(format!("permission '{action}' not present in jwt claims"))
337 }
338 }
339}
340
341#[allow(dead_code)]
342fn _assert_value_send_sync(_: Value) {}
343
#[cfg(test)]
mod tests {
    use super::*;
    use crate::http::StaticHttpClient;
    // Shadows `std::time::Duration` from `super::*` for building `exp` timestamps.
    use chrono::{Duration, Utc};
    use jsonwebtoken::{EncodingKey, Header, encode};
    use serde_json::json;

    // A bare permission matching the action grants it for any resource.
    #[test]
    fn jwt_claims_engine_allows_direct_permission() {
        let engine = JwtClaimsEngine::from_permissions("user_1", ["read".to_string()]);
        assert_eq!(
            engine.check("user_1", "read", "project/123"),
            PolicyResult::Allow
        );
    }

    // "project:edit" grants "edit" on resources whose type prefix is "project".
    #[test]
    fn jwt_claims_engine_allows_resource_type_permission() {
        let engine = JwtClaimsEngine::from_permissions("user_1", ["project:edit".to_string()]);
        assert_eq!(
            engine.check("user_1", "edit", "project/123"),
            PolicyResult::Allow
        );
    }

    // A missing permission delegates rather than denies.
    #[test]
    fn jwt_claims_engine_delegates_missing_permission() {
        let engine = JwtClaimsEngine::from_permissions("user_1", ["read".to_string()]);
        assert!(matches!(
            engine.check("user_1", "write", "project/123"),
            PolicyResult::Delegate(_)
        ));
    }

    // End-to-end: an HS256 token signed with b"secret" verifies against a JWKS
    // whose "oct" key `k` is "c2VjcmV0" (base64url of "secret").
    #[test]
    fn jwt_authenticator_verifies_hs256_token_from_jwks() {
        let jwks_url = "https://issuer.example/.well-known/jwks.json";
        let http = StaticHttpClient::new().with_response(
            jwks_url,
            json!({
                "keys": [{
                    "kty": "oct",
                    "kid": "test-key",
                    "alg": "HS256",
                    "k": "c2VjcmV0"
                }]
            }),
        );
        let mut config = OidcConfig::new("https://issuer.example", "typesec-test", jwks_url);
        config.algorithms = vec![Algorithm::HS256];
        let auth = JwtAuthenticator::with_http(config, Arc::new(http));

        let claims = JwtClaims {
            sub: "user_123".to_string(),
            iss: "https://issuer.example".to_string(),
            aud: Audience::Single("typesec-test".to_string()),
            exp: (Utc::now() + Duration::minutes(10)).timestamp() as usize,
            org_id: Some("org_123".to_string()),
            organization_membership_id: Some("om_123".to_string()),
            role: Some("org_member".to_string()),
            permissions: vec!["org:view".to_string(), "project:read".to_string()],
        };
        let mut header = Header::new(Algorithm::HS256);
        header.kid = Some("test-key".to_string());
        let token = encode(&header, &claims, &EncodingKey::from_secret(b"secret"))
            .expect("token should encode");

        let verified = auth.verify(&token).expect("token should verify");
        assert_eq!(verified.subject, "user_123");
        assert_eq!(verified.workos_membership_subject(), "om_123");
        assert_eq!(verified.permissions, vec!["org:view", "project:read"]);
    }

    // A validly signed token for a different audience must be rejected.
    #[test]
    fn jwt_authenticator_rejects_wrong_audience() {
        let jwks_url = "https://issuer.example/.well-known/jwks.json";
        let http = StaticHttpClient::new().with_response(
            jwks_url,
            json!({
                "keys": [{
                    "kty": "oct",
                    "kid": "test-key",
                    "alg": "HS256",
                    "k": "c2VjcmV0"
                }]
            }),
        );
        let mut config = OidcConfig::new("https://issuer.example", "typesec-test", jwks_url);
        config.algorithms = vec![Algorithm::HS256];
        let auth = JwtAuthenticator::with_http(config, Arc::new(http));

        let claims = JwtClaims {
            sub: "user_123".to_string(),
            iss: "https://issuer.example".to_string(),
            aud: Audience::Single("other-audience".to_string()),
            exp: (Utc::now() + Duration::minutes(10)).timestamp() as usize,
            org_id: None,
            organization_membership_id: None,
            role: None,
            permissions: vec![],
        };
        let mut header = Header::new(Algorithm::HS256);
        header.kid = Some("test-key".to_string());
        let token = encode(&header, &claims, &EncodingKey::from_secret(b"secret"))
            .expect("token should encode");

        assert!(auth.verify(&token).is_err());
    }

    /// Shared fixture: an HS256-only config plus a one-key JWKS (`kid` "test-key",
    /// secret "secret").
    fn hs256_config_and_jwks(jwks_url: &str) -> (OidcConfig, serde_json::Value) {
        let mut config = OidcConfig::new("https://issuer.example", "typesec-test", jwks_url);
        config.algorithms = vec![Algorithm::HS256];
        let jwks = json!({
            "keys": [{
                "kty": "oct",
                "kid": "test-key",
                "alg": "HS256",
                "k": "c2VjcmV0"
            }]
        });
        (config, jwks)
    }

    /// Signs a valid, unexpired HS256 token for "user_123" with the given header `kid`.
    fn hs256_token(kid: Option<&str>) -> String {
        let claims = JwtClaims {
            sub: "user_123".to_string(),
            iss: "https://issuer.example".to_string(),
            aud: Audience::Single("typesec-test".to_string()),
            exp: (Utc::now() + Duration::minutes(10)).timestamp() as usize,
            org_id: None,
            organization_membership_id: None,
            role: None,
            permissions: vec![],
        };
        let mut header = Header::new(Algorithm::HS256);
        header.kid = kid.map(str::to_owned);
        encode(&header, &claims, &EncodingKey::from_secret(b"secret")).expect("token encodes")
    }

    #[test]
    fn unknown_kid_triggers_one_jwks_refetch() {
        use crate::http::RecordingHttpClient;
        let jwks_url = "https://issuer.example/.well-known/jwks.json";
        let (config, jwks) = hs256_config_and_jwks(jwks_url);
        let http = RecordingHttpClient::new().with_response(jwks_url, jwks);
        let auth = JwtAuthenticator::with_http(config, Arc::new(http.clone()));

        let token = hs256_token(Some("rotated-away-key"));
        let result = auth.verify(&token);

        assert!(matches!(result, Err(JwtAuthError::MissingKey)));
        // One fetch to populate the cold cache, then exactly one forced refetch
        // for the unknown kid.
        assert_eq!(http.requests().len(), 2);
    }

    // Without a kid, a multi-key JWKS is ambiguous and must not be guessed from.
    #[test]
    fn missing_kid_with_multiple_keys_is_rejected() {
        let jwks_url = "https://issuer.example/.well-known/jwks.json";
        let (config, _) = hs256_config_and_jwks(jwks_url);
        let http = StaticHttpClient::new().with_response(
            jwks_url,
            json!({
                "keys": [
                    { "kty": "oct", "kid": "a", "alg": "HS256", "k": "c2VjcmV0" },
                    { "kty": "oct", "kid": "b", "alg": "HS256", "k": "b3RoZXI" }
                ]
            }),
        );
        let auth = JwtAuthenticator::with_http(config, Arc::new(http));

        let token = hs256_token(None);
        assert!(matches!(auth.verify(&token), Err(JwtAuthError::MissingKid)));
    }

    #[test]
    fn authenticator_rejects_mismatched_claimed_subject() {
        let jwks_url = "https://issuer.example/.well-known/jwks.json";
        let (config, jwks) = hs256_config_and_jwks(jwks_url);
        let http = StaticHttpClient::new().with_response(jwks_url, jwks);
        let auth = JwtAuthenticator::with_http(config, Arc::new(http));
        let token = hs256_token(Some("test-key"));

        let mismatched = Credentials::new("user_999", token.clone());
        // A valid token does not let the caller claim a different subject.
        assert!(auth.verify_credentials(&mismatched).is_err());

        let unclaimed = Credentials::new("", token);
        // An empty claimed subject is filled in from the verified token.
        assert_eq!(
            auth.verify_credentials(&unclaimed).expect("verifies"),
            "user_123"
        );
    }
}