1use std::collections::HashSet;
4use std::sync::{Arc, RwLock};
5use std::time::{Duration, Instant};
6
7use jsonwebtoken::{
8 Algorithm, DecodingKey, TokenData, Validation, decode, decode_header,
9 jwk::{Jwk, JwkSet},
10};
11use serde::{Deserialize, Serialize};
12use serde_json::Value;
13use tracing::debug;
14use typesec_core::typestate::{AgentError, Authenticator, Credentials};
15use typesec_core::{
16 ResourceId, SubjectId,
17 policy::{PolicyEngine, PolicyResult},
18};
19
20use crate::http::{HttpClient, ReqwestHttpClient};
21
22#[derive(Debug, Clone)]
24pub struct OidcConfig {
25 pub issuer: String,
27 pub audience: String,
29 pub jwks_url: String,
31 pub algorithms: Vec<Algorithm>,
33 pub jwks_ttl: Duration,
38}
39
40impl OidcConfig {
41 pub fn new(
43 issuer: impl Into<String>,
44 audience: impl Into<String>,
45 jwks_url: impl Into<String>,
46 ) -> Self {
47 Self {
48 issuer: issuer.into(),
49 audience: audience.into(),
50 jwks_url: jwks_url.into(),
51 algorithms: vec![Algorithm::RS256],
52 jwks_ttl: Duration::from_secs(300),
53 }
54 }
55}
56
57#[derive(Debug, Clone, Serialize, Deserialize)]
59pub struct JwtClaims {
60 pub sub: String,
62 pub iss: String,
64 pub aud: Audience,
66 pub exp: usize,
68 #[serde(default)]
70 pub org_id: Option<String>,
71 #[serde(default)]
73 pub organization_membership_id: Option<String>,
74 #[serde(default)]
76 pub role: Option<String>,
77 #[serde(default)]
79 pub permissions: Vec<String>,
80}
81
82#[derive(Debug, Clone, Serialize, Deserialize)]
84#[serde(untagged)]
85pub enum Audience {
86 Single(String),
88 Multiple(Vec<String>),
90}
91
92impl Audience {
93 fn contains(&self, needle: &str) -> bool {
94 match self {
95 Self::Single(value) => value == needle,
96 Self::Multiple(values) => values.iter().any(|value| value == needle),
97 }
98 }
99}
100
/// Identity information extracted from a token that passed verification.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct VerifiedSubject {
    /// Value of the token's `sub` claim.
    pub subject: String,
    /// Organization identifier, when the token carried one.
    pub org_id: Option<String>,
    /// Organization membership identifier, when the token carried one.
    pub organization_membership_id: Option<String>,
    /// Role names attached to the subject.
    pub roles: Vec<String>,
    /// Permission strings attached to the subject.
    pub permissions: Vec<String>,
}

impl VerifiedSubject {
    /// Returns the organization membership id when present, otherwise the
    /// plain subject.
    pub fn workos_membership_subject(&self) -> &str {
        match &self.organization_membership_id {
            Some(membership_id) => membership_id.as_str(),
            None => self.subject.as_str(),
        }
    }
}
124
125impl From<JwtClaims> for VerifiedSubject {
126 fn from(claims: JwtClaims) -> Self {
127 Self {
128 subject: claims.sub,
129 org_id: claims.org_id,
130 organization_membership_id: claims.organization_membership_id,
131 roles: claims.role.into_iter().collect(),
132 permissions: claims.permissions,
133 }
134 }
135}
136
137pub struct JwtAuthenticator {
139 config: OidcConfig,
140 http: Arc<dyn HttpClient>,
141 jwks: RwLock<Option<CachedJwks>>,
142}
143
144#[derive(Clone)]
145struct CachedJwks {
146 keys: JwkSet,
147 fetched_at: Instant,
148}
149
150impl JwtAuthenticator {
151 pub fn new(config: OidcConfig) -> Self {
153 Self::with_http(config, Arc::new(ReqwestHttpClient::new()))
154 }
155
156 pub fn with_http(config: OidcConfig, http: Arc<dyn HttpClient>) -> Self {
158 Self {
159 config,
160 http,
161 jwks: RwLock::new(None),
162 }
163 }
164
165 pub fn verify(&self, token: &str) -> Result<VerifiedSubject, JwtAuthError> {
167 let data = self.decode_claims(token)?;
168 if !data.claims.aud.contains(&self.config.audience) {
169 return Err(JwtAuthError::InvalidAudience);
170 }
171 Ok(data.claims.into())
172 }
173
174 fn decode_claims(&self, token: &str) -> Result<TokenData<JwtClaims>, JwtAuthError> {
175 let header = decode_header(token)?;
176 let key = self.resolve_key(header.kid.as_deref())?;
177
178 let mut validation = Validation::new(header.alg);
179 validation.algorithms = self.config.algorithms.clone();
180 validation.set_issuer(&[self.config.issuer.as_str()]);
181 validation.set_audience(&[self.config.audience.as_str()]);
182
183 Ok(decode::<JwtClaims>(
184 token,
185 &DecodingKey::from_jwk(&key)?,
186 &validation,
187 )?)
188 }
189
190 fn resolve_key(&self, kid: Option<&str>) -> Result<Jwk, JwtAuthError> {
198 let jwks = self.jwks(false)?;
199 match kid {
200 Some(kid) => {
201 if let Some(key) = jwks.find(kid) {
202 return Ok(key.clone());
203 }
204 let jwks = self.jwks(true)?;
206 jwks.find(kid).cloned().ok_or(JwtAuthError::MissingKey)
207 }
208 None => match jwks.keys.as_slice() {
209 [only] => Ok(only.clone()),
210 [] => Err(JwtAuthError::MissingKey),
211 _ => Err(JwtAuthError::MissingKid),
212 },
213 }
214 }
215
216 fn jwks(&self, force_refresh: bool) -> Result<JwkSet, JwtAuthError> {
217 if !force_refresh
218 && let Some(cached) = self.jwks.read().expect("jwks lock poisoned").as_ref()
219 && cached.fetched_at.elapsed() < self.config.jwks_ttl
220 {
221 return Ok(cached.keys.clone());
222 }
223
224 let value = self.http.get_json(&self.config.jwks_url, &[])?;
225 let keys: JwkSet = serde_json::from_value(value)?;
226 *self.jwks.write().expect("jwks lock poisoned") = Some(CachedJwks {
227 keys: keys.clone(),
228 fetched_at: Instant::now(),
229 });
230 Ok(keys)
231 }
232}
233
234impl Authenticator for JwtAuthenticator {
235 fn verify_credentials(&self, credentials: &Credentials) -> Result<String, AgentError> {
241 let verified =
242 self.verify(credentials.token.expose())
243 .map_err(|e| AgentError::AuthFailed {
244 reason: format!("jwt verification failed: {e}"),
245 })?;
246 if !credentials.subject.is_empty() && credentials.subject != verified.subject {
247 return Err(AgentError::AuthFailed {
248 reason: format!(
249 "claimed subject '{}' does not match verified token subject '{}'",
250 credentials.subject, verified.subject
251 ),
252 });
253 }
254 Ok(verified.subject)
255 }
256}
257
258#[derive(Debug, thiserror::Error)]
260pub enum JwtAuthError {
261 #[error("jwt validation failed: {0}")]
263 Jwt(#[from] jsonwebtoken::errors::Error),
264 #[error("jwks fetch failed: {0}")]
266 Http(#[from] Box<dyn std::error::Error + Send + Sync>),
267 #[error("jwks parse failed: {0}")]
269 Json(#[from] serde_json::Error),
270 #[error("no matching signing key found in JWKS")]
272 MissingKey,
273 #[error("token has no kid but JWKS is ambiguous (multiple keys)")]
275 MissingKid,
276 #[error("token audience did not match expected audience")]
278 InvalidAudience,
279}
280
281pub struct JwtClaimsEngine {
287 subject: String,
288 permissions: HashSet<String>,
289 org_id: Option<String>,
290}
291
292impl JwtClaimsEngine {
293 pub fn new(subject: VerifiedSubject) -> Self {
295 Self {
296 subject: subject.subject,
297 permissions: subject.permissions.into_iter().collect(),
298 org_id: subject.org_id,
299 }
300 }
301
302 pub fn from_permissions(
304 subject: impl Into<String>,
305 permissions: impl IntoIterator<Item = String>,
306 ) -> Self {
307 Self {
308 subject: subject.into(),
309 permissions: permissions.into_iter().collect(),
310 org_id: None,
311 }
312 }
313
314 fn permission_matches(&self, action: &str, resource: &str) -> bool {
315 if self.permissions.contains(action) {
316 return true;
317 }
318
319 let resource_type = resource.split(['/', ':']).next().unwrap_or(resource);
320 self.permissions
321 .contains(&format!("{resource_type}:{action}"))
322 }
323}
324
325impl PolicyEngine for JwtClaimsEngine {
326 fn check(&self, subject: &SubjectId, action: &str, resource: &ResourceId) -> PolicyResult {
327 let subject = subject.as_str();
328 let resource = resource.as_str();
329 debug!(subject, action, resource, org_id = ?self.org_id, "jwt claims check");
330
331 if subject != self.subject {
332 return PolicyResult::delegate(
333 "jwt",
334 format!("jwt claims are for '{}', not '{subject}'", self.subject),
335 );
336 }
337
338 if self.permission_matches(action, resource) {
339 PolicyResult::Allow
340 } else {
341 PolicyResult::delegate(
342 "jwt",
343 format!("permission '{action}' not present in jwt claims"),
344 )
345 }
346 }
347}
348
349#[allow(dead_code)]
350fn _assert_value_send_sync(_: Value) {}
351
#[cfg(test)]
mod tests {
    use super::*;
    use crate::http::StaticHttpClient;
    use chrono::{Duration, Utc};
    use jsonwebtoken::{EncodingKey, Header, encode};
    use serde_json::json;

    /// JWKS location shared by every authenticator test.
    const JWKS_URL: &str = "https://issuer.example/.well-known/jwks.json";

    /// Runs a policy check from plain string arguments.
    fn check(
        engine: &JwtClaimsEngine,
        subject: &str,
        action: &str,
        resource: &str,
    ) -> PolicyResult {
        let subject = SubjectId::from(subject);
        let resource = ResourceId::from(resource);
        engine.check(&subject, action, &resource)
    }

    /// HS256 configuration plus a one-key JWKS whose `k` is base64 for
    /// `secret`.
    fn hs256_config_and_jwks(jwks_url: &str) -> (OidcConfig, serde_json::Value) {
        let mut config = OidcConfig::new("https://issuer.example", "typesec-test", jwks_url);
        config.algorithms = vec![Algorithm::HS256];
        let jwks = json!({
            "keys": [{
                "kty": "oct",
                "kid": "test-key",
                "alg": "HS256",
                "k": "c2VjcmV0"
            }]
        });
        (config, jwks)
    }

    /// Authenticator backed by a static client serving the default JWKS.
    fn default_authenticator() -> JwtAuthenticator {
        let (config, jwks) = hs256_config_and_jwks(JWKS_URL);
        let http = StaticHttpClient::new().with_response(JWKS_URL, jwks);
        JwtAuthenticator::with_http(config, Arc::new(http))
    }

    /// Claims for `user_123` with the expected issuer and audience, expiring
    /// in ten minutes, and no optional claims.
    fn baseline_claims() -> JwtClaims {
        JwtClaims {
            sub: "user_123".to_string(),
            iss: "https://issuer.example".to_string(),
            aud: Audience::Single("typesec-test".to_string()),
            exp: (Utc::now() + Duration::minutes(10)).timestamp() as usize,
            org_id: None,
            organization_membership_id: None,
            role: None,
            permissions: vec![],
        }
    }

    /// Signs `claims` with the shared HS256 secret and the given `kid`.
    fn sign_hs256(claims: &JwtClaims, kid: Option<&str>) -> jsonwebtoken::errors::Result<String> {
        let mut header = Header::new(Algorithm::HS256);
        header.kid = kid.map(str::to_owned);
        encode(&header, claims, &EncodingKey::from_secret(b"secret"))
    }

    /// Baseline token signed with the given `kid`.
    fn hs256_token(kid: Option<&str>) -> String {
        sign_hs256(&baseline_claims(), kid).expect("token encodes")
    }

    #[test]
    fn jwt_claims_engine_allows_direct_permission() {
        let engine = JwtClaimsEngine::from_permissions("user_1", ["read".to_string()]);
        let outcome = check(&engine, "user_1", "read", "project/123");
        assert_eq!(outcome, PolicyResult::Allow);
    }

    #[test]
    fn jwt_claims_engine_allows_resource_type_permission() {
        let engine = JwtClaimsEngine::from_permissions("user_1", ["project:edit".to_string()]);
        let outcome = check(&engine, "user_1", "edit", "project/123");
        assert_eq!(outcome, PolicyResult::Allow);
    }

    #[test]
    fn jwt_claims_engine_delegates_missing_permission() {
        let engine = JwtClaimsEngine::from_permissions("user_1", ["read".to_string()]);
        let outcome = check(&engine, "user_1", "write", "project/123");
        assert!(matches!(outcome, PolicyResult::Delegate(_)));
    }

    #[test]
    fn jwt_authenticator_verifies_hs256_token_from_jwks() {
        let auth = default_authenticator();

        let claims = JwtClaims {
            org_id: Some("org_123".to_string()),
            organization_membership_id: Some("om_123".to_string()),
            role: Some("org_member".to_string()),
            permissions: vec!["org:view".to_string(), "project:read".to_string()],
            ..baseline_claims()
        };
        let token = sign_hs256(&claims, Some("test-key")).expect("token should encode");

        let verified = auth.verify(&token).expect("token should verify");
        assert_eq!(verified.subject, "user_123");
        assert_eq!(verified.workos_membership_subject(), "om_123");
        assert_eq!(verified.permissions, vec!["org:view", "project:read"]);
    }

    #[test]
    fn jwt_authenticator_rejects_wrong_audience() {
        let auth = default_authenticator();

        let claims = JwtClaims {
            aud: Audience::Single("other-audience".to_string()),
            ..baseline_claims()
        };
        let token = sign_hs256(&claims, Some("test-key")).expect("token should encode");

        assert!(auth.verify(&token).is_err());
    }

    #[test]
    fn unknown_kid_triggers_one_jwks_refetch() {
        use crate::http::RecordingHttpClient;

        let (config, jwks) = hs256_config_and_jwks(JWKS_URL);
        let http = RecordingHttpClient::new().with_response(JWKS_URL, jwks);
        let auth = JwtAuthenticator::with_http(config, Arc::new(http.clone()));

        let token = hs256_token(Some("rotated-away-key"));
        let result = auth.verify(&token);

        assert!(matches!(result, Err(JwtAuthError::MissingKey)));
        // One fetch to fill the empty cache, one forced refresh for the
        // unknown kid.
        assert_eq!(http.requests().len(), 2);
    }

    #[test]
    fn missing_kid_with_multiple_keys_is_rejected() {
        let (config, _) = hs256_config_and_jwks(JWKS_URL);
        let two_keys = json!({
            "keys": [
                { "kty": "oct", "kid": "a", "alg": "HS256", "k": "c2VjcmV0" },
                { "kty": "oct", "kid": "b", "alg": "HS256", "k": "b3RoZXI" }
            ]
        });
        let http = StaticHttpClient::new().with_response(JWKS_URL, two_keys);
        let auth = JwtAuthenticator::with_http(config, Arc::new(http));

        let token = hs256_token(None);
        assert!(matches!(auth.verify(&token), Err(JwtAuthError::MissingKid)));
    }

    #[test]
    fn authenticator_rejects_mismatched_claimed_subject() {
        let auth = default_authenticator();
        let token = hs256_token(Some("test-key"));

        // A claimed subject that differs from the token's `sub` is refused.
        let mismatched = Credentials::new("user_999", token.clone());
        assert!(auth.verify_credentials(&mismatched).is_err());

        // An empty claimed subject is accepted and the token's `sub` returned.
        let unclaimed = Credentials::new("", token);
        assert_eq!(
            auth.verify_credentials(&unclaimed).expect("verifies"),
            "user_123"
        );
    }
}