Skip to main content

typesec_integrations/
jwt.rs

1//! JWT/OIDC authentication helpers and a fast claims-backed policy engine.
2
3use std::collections::HashSet;
4use std::sync::{Arc, RwLock};
5use std::time::{Duration, Instant};
6
7use jsonwebtoken::{
8    Algorithm, DecodingKey, TokenData, Validation, decode, decode_header,
9    jwk::{Jwk, JwkSet},
10};
11use serde::{Deserialize, Serialize};
12use serde_json::Value;
13use tracing::debug;
14use typesec_core::policy::{PolicyEngine, PolicyResult};
15use typesec_core::typestate::{AgentError, Authenticator, Credentials};
16
17use crate::http::{HttpClient, ReqwestHttpClient};
18
19/// OIDC validation settings.
20#[derive(Debug, Clone)]
21pub struct OidcConfig {
22    /// Expected issuer claim.
23    pub issuer: String,
24    /// Expected audience claim.
25    pub audience: String,
26    /// JWKS endpoint used to resolve signing keys.
27    pub jwks_url: String,
28    /// Accepted signing algorithms.
29    pub algorithms: Vec<Algorithm>,
30    /// How long fetched JWKS keys are cached before re-fetching.
31    ///
32    /// The cache is also refreshed eagerly when a token references an unknown
33    /// `kid`, so key rotation at the IdP is picked up without a restart.
34    pub jwks_ttl: Duration,
35}
36
37impl OidcConfig {
38    /// Create a config using RS256, the common AuthKit/OIDC default.
39    pub fn new(
40        issuer: impl Into<String>,
41        audience: impl Into<String>,
42        jwks_url: impl Into<String>,
43    ) -> Self {
44        Self {
45            issuer: issuer.into(),
46            audience: audience.into(),
47            jwks_url: jwks_url.into(),
48            algorithms: vec![Algorithm::RS256],
49            jwks_ttl: Duration::from_secs(300),
50        }
51    }
52}
53
54/// Claims Typesec cares about from an access token.
55#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct JwtClaims {
57    /// Subject identifier.
58    pub sub: String,
59    /// Issuer.
60    pub iss: String,
61    /// Audience. Some providers encode this as a string, others as a list.
62    pub aud: Audience,
63    /// Expiration timestamp.
64    pub exp: usize,
65    /// Optional organization identifier.
66    #[serde(default)]
67    pub org_id: Option<String>,
68    /// Optional organization membership identifier.
69    #[serde(default)]
70    pub organization_membership_id: Option<String>,
71    /// Optional role.
72    #[serde(default)]
73    pub role: Option<String>,
74    /// Optional permission list.
75    #[serde(default)]
76    pub permissions: Vec<String>,
77}
78
79/// JWT audience represented as either a string or list.
80#[derive(Debug, Clone, Serialize, Deserialize)]
81#[serde(untagged)]
82pub enum Audience {
83    /// Single audience.
84    Single(String),
85    /// Multiple audiences.
86    Multiple(Vec<String>),
87}
88
89impl Audience {
90    fn contains(&self, needle: &str) -> bool {
91        match self {
92            Self::Single(value) => value == needle,
93            Self::Multiple(values) => values.iter().any(|value| value == needle),
94        }
95    }
96}
97
/// Identity and authorization facts taken from a token that passed
/// verification.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct VerifiedSubject {
    /// The token's `sub` claim.
    pub subject: String,
    /// Organization identifier, when the token carries one.
    pub org_id: Option<String>,
    /// Organization membership identifier, when the token carries one.
    pub organization_membership_id: Option<String>,
    /// Role names from the token; conversion from [`JwtClaims`] yields at
    /// most one (from its `role` claim).
    pub roles: Vec<String>,
    /// Permission names from the token's `permissions` claim.
    pub permissions: Vec<String>,
}
112
113impl VerifiedSubject {
114    /// Return the best subject identifier for WorkOS FGA checks.
115    pub fn workos_membership_subject(&self) -> &str {
116        self.organization_membership_id
117            .as_deref()
118            .unwrap_or(&self.subject)
119    }
120}
121
122impl From<JwtClaims> for VerifiedSubject {
123    fn from(claims: JwtClaims) -> Self {
124        Self {
125            subject: claims.sub,
126            org_id: claims.org_id,
127            organization_membership_id: claims.organization_membership_id,
128            roles: claims.role.into_iter().collect(),
129            permissions: claims.permissions,
130        }
131    }
132}
133
134/// JWT authenticator that verifies tokens against a JWKS endpoint.
135pub struct JwtAuthenticator {
136    config: OidcConfig,
137    http: Arc<dyn HttpClient>,
138    jwks: RwLock<Option<CachedJwks>>,
139}
140
141#[derive(Clone)]
142struct CachedJwks {
143    keys: JwkSet,
144    fetched_at: Instant,
145}
146
147impl JwtAuthenticator {
148    /// Create an authenticator using the default reqwest HTTP client.
149    pub fn new(config: OidcConfig) -> Self {
150        Self::with_http(config, Arc::new(ReqwestHttpClient::new()))
151    }
152
153    /// Create an authenticator with an injected HTTP client.
154    pub fn with_http(config: OidcConfig, http: Arc<dyn HttpClient>) -> Self {
155        Self {
156            config,
157            http,
158            jwks: RwLock::new(None),
159        }
160    }
161
162    /// Verify a bearer token and return its Typesec subject model.
163    pub fn verify(&self, token: &str) -> Result<VerifiedSubject, JwtAuthError> {
164        let data = self.decode_claims(token)?;
165        if !data.claims.aud.contains(&self.config.audience) {
166            return Err(JwtAuthError::InvalidAudience);
167        }
168        Ok(data.claims.into())
169    }
170
171    fn decode_claims(&self, token: &str) -> Result<TokenData<JwtClaims>, JwtAuthError> {
172        let header = decode_header(token)?;
173        let key = self.resolve_key(header.kid.as_deref())?;
174
175        let mut validation = Validation::new(header.alg);
176        validation.algorithms = self.config.algorithms.clone();
177        validation.set_issuer(&[self.config.issuer.as_str()]);
178        validation.set_audience(&[self.config.audience.as_str()]);
179
180        Ok(decode::<JwtClaims>(
181            token,
182            &DecodingKey::from_jwk(&key)?,
183            &validation,
184        )?)
185    }
186
187    /// Resolve the signing key for a token header.
188    ///
189    /// - With a `kid`: look it up in the cached JWKS; on a miss, re-fetch the
190    ///   JWKS once (the IdP may have rotated keys) before failing.
191    /// - Without a `kid`: only unambiguous key sets are accepted — if the JWKS
192    ///   holds more than one key, the token is rejected rather than verified
193    ///   against an arbitrary key.
194    fn resolve_key(&self, kid: Option<&str>) -> Result<Jwk, JwtAuthError> {
195        let jwks = self.jwks(false)?;
196        match kid {
197            Some(kid) => {
198                if let Some(key) = jwks.find(kid) {
199                    return Ok(key.clone());
200                }
201                // Unknown kid — refresh the JWKS once in case of key rotation.
202                let jwks = self.jwks(true)?;
203                jwks.find(kid).cloned().ok_or(JwtAuthError::MissingKey)
204            }
205            None => match jwks.keys.as_slice() {
206                [only] => Ok(only.clone()),
207                [] => Err(JwtAuthError::MissingKey),
208                _ => Err(JwtAuthError::MissingKid),
209            },
210        }
211    }
212
213    fn jwks(&self, force_refresh: bool) -> Result<JwkSet, JwtAuthError> {
214        if !force_refresh
215            && let Some(cached) = self.jwks.read().expect("jwks lock poisoned").as_ref()
216            && cached.fetched_at.elapsed() < self.config.jwks_ttl
217        {
218            return Ok(cached.keys.clone());
219        }
220
221        let value = self.http.get_json(&self.config.jwks_url, &[])?;
222        let keys: JwkSet = serde_json::from_value(value)?;
223        *self.jwks.write().expect("jwks lock poisoned") = Some(CachedJwks {
224            keys: keys.clone(),
225            fetched_at: Instant::now(),
226        });
227        Ok(keys)
228    }
229}
230
231impl Authenticator for JwtAuthenticator {
232    /// Verify the credential token as a JWT and return the *verified* subject.
233    ///
234    /// If the credentials claim a subject, it must match the token's `sub`
235    /// claim — a caller cannot authenticate as someone else's identity by
236    /// pairing a valid token with a different claimed subject.
237    fn verify_credentials(&self, credentials: &Credentials) -> Result<String, AgentError> {
238        let verified =
239            self.verify(credentials.token.expose())
240                .map_err(|e| AgentError::AuthFailed {
241                    reason: format!("jwt verification failed: {e}"),
242                })?;
243        if !credentials.subject.is_empty() && credentials.subject != verified.subject {
244            return Err(AgentError::AuthFailed {
245                reason: format!(
246                    "claimed subject '{}' does not match verified token subject '{}'",
247                    credentials.subject, verified.subject
248                ),
249            });
250        }
251        Ok(verified.subject)
252    }
253}
254
255/// Errors returned by [`JwtAuthenticator`].
256#[derive(Debug, thiserror::Error)]
257pub enum JwtAuthError {
258    /// Token validation failed.
259    #[error("jwt validation failed: {0}")]
260    Jwt(#[from] jsonwebtoken::errors::Error),
261    /// JWKS fetch failed.
262    #[error("jwks fetch failed: {0}")]
263    Http(#[from] Box<dyn std::error::Error + Send + Sync>),
264    /// JWKS JSON could not be parsed.
265    #[error("jwks parse failed: {0}")]
266    Json(#[from] serde_json::Error),
267    /// No matching signing key was found.
268    #[error("no matching signing key found in JWKS")]
269    MissingKey,
270    /// The token has no `kid` and the JWKS holds multiple keys.
271    #[error("token has no kid but JWKS is ambiguous (multiple keys)")]
272    MissingKid,
273    /// Token audience did not match the configured audience.
274    #[error("token audience did not match expected audience")]
275    InvalidAudience,
276}
277
/// [`PolicyEngine`] that answers from the permission claims of one verified
/// token.
///
/// It is meant as a cheap first layer in a composed engine. It allows requests
/// that the token's permissions clearly cover and delegates everything else to
/// a more precise engine (RBAC, ODRL, WorkOS FGA, ...). It never denies.
pub struct JwtClaimsEngine {
    /// Subject the claims belong to; checks for any other subject delegate.
    subject: String,
    /// Granted permissions: bare actions or `resource_type:action` pairs.
    permissions: HashSet<String>,
    /// Organization from the token; in this file it is only logged.
    org_id: Option<String>,
}
288
289impl JwtClaimsEngine {
290    /// Build an engine from a verified subject.
291    pub fn new(subject: VerifiedSubject) -> Self {
292        Self {
293            subject: subject.subject,
294            permissions: subject.permissions.into_iter().collect(),
295            org_id: subject.org_id,
296        }
297    }
298
299    /// Build an engine from raw permission strings.
300    pub fn from_permissions(
301        subject: impl Into<String>,
302        permissions: impl IntoIterator<Item = String>,
303    ) -> Self {
304        Self {
305            subject: subject.into(),
306            permissions: permissions.into_iter().collect(),
307            org_id: None,
308        }
309    }
310
311    fn permission_matches(&self, action: &str, resource: &str) -> bool {
312        if self.permissions.contains(action) {
313            return true;
314        }
315
316        let resource_type = resource.split(['/', ':']).next().unwrap_or(resource);
317        self.permissions
318            .contains(&format!("{resource_type}:{action}"))
319    }
320}
321
322impl PolicyEngine for JwtClaimsEngine {
323    fn check(&self, subject: &str, action: &str, resource: &str) -> PolicyResult {
324        debug!(subject, action, resource, org_id = ?self.org_id, "jwt claims check");
325
326        if subject != self.subject {
327            return PolicyResult::Delegate(format!(
328                "jwt claims are for '{}', not '{subject}'",
329                self.subject
330            ));
331        }
332
333        if self.permission_matches(action, resource) {
334            PolicyResult::Allow
335        } else {
336            PolicyResult::Delegate(format!("permission '{action}' not present in jwt claims"))
337        }
338    }
339}
340
341#[allow(dead_code)]
342fn _assert_value_send_sync(_: Value) {}
343
#[cfg(test)]
mod tests {
    use super::*;
    use crate::http::{RecordingHttpClient, StaticHttpClient};
    use chrono::{Duration, Utc};
    use jsonwebtoken::{EncodingKey, Header, encode};
    use serde_json::json;

    const JWKS_URL: &str = "https://issuer.example/.well-known/jwks.json";
    const ISSUER: &str = "https://issuer.example";
    const AUDIENCE: &str = "typesec-test";

    /// HS256-only config plus a one-key JWKS whose `test-key` secret is
    /// `b"secret"` (base64url `c2VjcmV0`).
    fn hs256_config_and_jwks(jwks_url: &str) -> (OidcConfig, serde_json::Value) {
        let mut config = OidcConfig::new(ISSUER, AUDIENCE, jwks_url);
        config.algorithms = vec![Algorithm::HS256];
        let jwks = json!({
            "keys": [{
                "kty": "oct",
                "kid": "test-key",
                "alg": "HS256",
                "k": "c2VjcmV0"
            }]
        });
        (config, jwks)
    }

    /// Authenticator backed by a static HTTP client serving the test JWKS.
    fn static_authenticator() -> JwtAuthenticator {
        let (config, jwks) = hs256_config_and_jwks(JWKS_URL);
        let http = StaticHttpClient::new().with_response(JWKS_URL, jwks);
        JwtAuthenticator::with_http(config, Arc::new(http))
    }

    /// Claims for `user_123`, issued by the test issuer to the test audience
    /// and valid for ten minutes; tests override fields via struct update.
    fn base_claims() -> JwtClaims {
        JwtClaims {
            sub: "user_123".to_string(),
            iss: ISSUER.to_string(),
            aud: Audience::Single(AUDIENCE.to_string()),
            exp: (Utc::now() + Duration::minutes(10)).timestamp() as usize,
            org_id: None,
            organization_membership_id: None,
            role: None,
            permissions: vec![],
        }
    }

    /// Sign `claims` with HS256 using the test secret and the given `kid`.
    fn sign_hs256(kid: Option<&str>, claims: &JwtClaims) -> String {
        let mut header = Header::new(Algorithm::HS256);
        header.kid = kid.map(str::to_owned);
        encode(&header, claims, &EncodingKey::from_secret(b"secret")).expect("token encodes")
    }

    /// A valid test token with the given `kid`.
    fn hs256_token(kid: Option<&str>) -> String {
        sign_hs256(kid, &base_claims())
    }

    #[test]
    fn jwt_claims_engine_allows_direct_permission() {
        let engine = JwtClaimsEngine::from_permissions("user_1", ["read".to_string()]);
        assert_eq!(
            engine.check("user_1", "read", "project/123"),
            PolicyResult::Allow
        );
    }

    #[test]
    fn jwt_claims_engine_allows_resource_type_permission() {
        let engine = JwtClaimsEngine::from_permissions("user_1", ["project:edit".to_string()]);
        assert_eq!(
            engine.check("user_1", "edit", "project/123"),
            PolicyResult::Allow
        );
    }

    #[test]
    fn jwt_claims_engine_delegates_missing_permission() {
        let engine = JwtClaimsEngine::from_permissions("user_1", ["read".to_string()]);
        assert!(matches!(
            engine.check("user_1", "write", "project/123"),
            PolicyResult::Delegate(_)
        ));
    }

    #[test]
    fn jwt_authenticator_verifies_hs256_token_from_jwks() {
        let auth = static_authenticator();
        let claims = JwtClaims {
            org_id: Some("org_123".to_string()),
            organization_membership_id: Some("om_123".to_string()),
            role: Some("org_member".to_string()),
            permissions: vec!["org:view".to_string(), "project:read".to_string()],
            ..base_claims()
        };
        let token = sign_hs256(Some("test-key"), &claims);

        let verified = auth.verify(&token).expect("token should verify");
        assert_eq!(verified.subject, "user_123");
        assert_eq!(verified.org_id.as_deref(), Some("org_123"));
        assert_eq!(verified.roles, vec!["org_member"]);
        assert_eq!(verified.workos_membership_subject(), "om_123");
        assert_eq!(verified.permissions, vec!["org:view", "project:read"]);
    }

    #[test]
    fn jwt_authenticator_rejects_wrong_audience() {
        let auth = static_authenticator();
        let claims = JwtClaims {
            aud: Audience::Single("other-audience".to_string()),
            ..base_claims()
        };
        let token = sign_hs256(Some("test-key"), &claims);

        let err = auth
            .verify(&token)
            .expect_err("wrong audience must be rejected");
        // The audience is checked both by jsonwebtoken (`set_audience`) and by
        // `verify` itself. Either is acceptable, but the failure must be about
        // the audience, not an unrelated key or signature problem.
        let is_audience_error = match &err {
            JwtAuthError::InvalidAudience => true,
            JwtAuthError::Jwt(inner) => matches!(
                inner.kind(),
                jsonwebtoken::errors::ErrorKind::InvalidAudience
            ),
            _ => false,
        };
        assert!(is_audience_error, "expected an audience error, got: {err}");
    }

    #[test]
    fn unknown_kid_triggers_one_jwks_refetch() {
        let (config, jwks) = hs256_config_and_jwks(JWKS_URL);
        let http = RecordingHttpClient::new().with_response(JWKS_URL, jwks);
        let auth = JwtAuthenticator::with_http(config, Arc::new(http.clone()));

        let token = hs256_token(Some("rotated-away-key"));
        let result = auth.verify(&token);

        assert!(matches!(result, Err(JwtAuthError::MissingKey)));
        // Initial fetch + one rotation-driven refetch, no more.
        assert_eq!(http.requests().len(), 2);
    }

    #[test]
    fn missing_kid_with_multiple_keys_is_rejected() {
        let (config, _) = hs256_config_and_jwks(JWKS_URL);
        let http = StaticHttpClient::new().with_response(
            JWKS_URL,
            json!({
                "keys": [
                    { "kty": "oct", "kid": "a", "alg": "HS256", "k": "c2VjcmV0" },
                    { "kty": "oct", "kid": "b", "alg": "HS256", "k": "b3RoZXI" }
                ]
            }),
        );
        let auth = JwtAuthenticator::with_http(config, Arc::new(http));

        let token = hs256_token(None);
        assert!(matches!(auth.verify(&token), Err(JwtAuthError::MissingKid)));
    }

    #[test]
    fn authenticator_rejects_mismatched_claimed_subject() {
        let auth = static_authenticator();
        let token = hs256_token(Some("test-key"));

        // Claiming someone else's identity with a valid token must fail.
        let mismatched = Credentials::new("user_999", token.clone());
        assert!(auth.verify_credentials(&mismatched).is_err());

        // The verified subject wins; an empty claimed subject is allowed.
        let unclaimed = Credentials::new("", token);
        assert_eq!(
            auth.verify_credentials(&unclaimed).expect("verifies"),
            "user_123"
        );
    }
}