Skip to main content

typesec_integrations/
jwt.rs

1//! JWT/OIDC authentication helpers and a fast claims-backed policy engine.
2
3use std::collections::HashSet;
4use std::sync::{Arc, RwLock};
5use std::time::{Duration, Instant};
6
7use jsonwebtoken::{
8    Algorithm, DecodingKey, TokenData, Validation, decode, decode_header,
9    jwk::{Jwk, JwkSet},
10};
11use serde::{Deserialize, Serialize};
12use serde_json::Value;
13use tracing::debug;
14use typesec_core::typestate::{AgentError, Authenticator, Credentials};
15use typesec_core::{
16    ResourceId, SubjectId,
17    policy::{PolicyEngine, PolicyResult},
18};
19
20use crate::http::{HttpClient, ReqwestHttpClient};
21
22/// OIDC validation settings.
23#[derive(Debug, Clone)]
24pub struct OidcConfig {
25    /// Expected issuer claim.
26    pub issuer: String,
27    /// Expected audience claim.
28    pub audience: String,
29    /// JWKS endpoint used to resolve signing keys.
30    pub jwks_url: String,
31    /// Accepted signing algorithms.
32    pub algorithms: Vec<Algorithm>,
33    /// How long fetched JWKS keys are cached before re-fetching.
34    ///
35    /// The cache is also refreshed eagerly when a token references an unknown
36    /// `kid`, so key rotation at the IdP is picked up without a restart.
37    pub jwks_ttl: Duration,
38}
39
40impl OidcConfig {
41    /// Create a config using RS256, the common AuthKit/OIDC default.
42    pub fn new(
43        issuer: impl Into<String>,
44        audience: impl Into<String>,
45        jwks_url: impl Into<String>,
46    ) -> Self {
47        Self {
48            issuer: issuer.into(),
49            audience: audience.into(),
50            jwks_url: jwks_url.into(),
51            algorithms: vec![Algorithm::RS256],
52            jwks_ttl: Duration::from_secs(300),
53        }
54    }
55}
56
57/// Claims Typesec cares about from an access token.
58#[derive(Debug, Clone, Serialize, Deserialize)]
59pub struct JwtClaims {
60    /// Subject identifier.
61    pub sub: String,
62    /// Issuer.
63    pub iss: String,
64    /// Audience. Some providers encode this as a string, others as a list.
65    pub aud: Audience,
66    /// Expiration timestamp.
67    pub exp: usize,
68    /// Optional organization identifier.
69    #[serde(default)]
70    pub org_id: Option<String>,
71    /// Optional organization membership identifier.
72    #[serde(default)]
73    pub organization_membership_id: Option<String>,
74    /// Optional role.
75    #[serde(default)]
76    pub role: Option<String>,
77    /// Optional permission list.
78    #[serde(default)]
79    pub permissions: Vec<String>,
80}
81
82/// JWT audience represented as either a string or list.
83#[derive(Debug, Clone, Serialize, Deserialize)]
84#[serde(untagged)]
85pub enum Audience {
86    /// Single audience.
87    Single(String),
88    /// Multiple audiences.
89    Multiple(Vec<String>),
90}
91
92impl Audience {
93    fn contains(&self, needle: &str) -> bool {
94        match self {
95            Self::Single(value) => value == needle,
96            Self::Multiple(values) => values.iter().any(|value| value == needle),
97        }
98    }
99}
100
/// Identity taken from an OIDC/JWT access token after it has been verified.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct VerifiedSubject {
    /// Subject identifier.
    pub subject: String,
    /// Organization identifier, when the token carried one.
    pub org_id: Option<String>,
    /// Organization membership identifier, when the token carried one.
    pub organization_membership_id: Option<String>,
    /// Role names carried by the token.
    pub roles: Vec<String>,
    /// Permission names carried by the token.
    pub permissions: Vec<String>,
}

impl VerifiedSubject {
    /// Return the identifier to use as the subject in WorkOS FGA checks.
    ///
    /// The organization membership id is preferred; the plain subject is the
    /// fallback when no membership id is present.
    pub fn workos_membership_subject(&self) -> &str {
        match &self.organization_membership_id {
            Some(membership_id) => membership_id.as_str(),
            None => self.subject.as_str(),
        }
    }
}
124
125impl From<JwtClaims> for VerifiedSubject {
126    fn from(claims: JwtClaims) -> Self {
127        Self {
128            subject: claims.sub,
129            org_id: claims.org_id,
130            organization_membership_id: claims.organization_membership_id,
131            roles: claims.role.into_iter().collect(),
132            permissions: claims.permissions,
133        }
134    }
135}
136
137/// JWT authenticator that verifies tokens against a JWKS endpoint.
138pub struct JwtAuthenticator {
139    config: OidcConfig,
140    http: Arc<dyn HttpClient>,
141    jwks: RwLock<Option<CachedJwks>>,
142}
143
144#[derive(Clone)]
145struct CachedJwks {
146    keys: JwkSet,
147    fetched_at: Instant,
148}
149
150impl JwtAuthenticator {
151    /// Create an authenticator using the default reqwest HTTP client.
152    pub fn new(config: OidcConfig) -> Self {
153        Self::with_http(config, Arc::new(ReqwestHttpClient::new()))
154    }
155
156    /// Create an authenticator with an injected HTTP client.
157    pub fn with_http(config: OidcConfig, http: Arc<dyn HttpClient>) -> Self {
158        Self {
159            config,
160            http,
161            jwks: RwLock::new(None),
162        }
163    }
164
165    /// Verify a bearer token and return its Typesec subject model.
166    pub fn verify(&self, token: &str) -> Result<VerifiedSubject, JwtAuthError> {
167        let data = self.decode_claims(token)?;
168        if !data.claims.aud.contains(&self.config.audience) {
169            return Err(JwtAuthError::InvalidAudience);
170        }
171        Ok(data.claims.into())
172    }
173
174    fn decode_claims(&self, token: &str) -> Result<TokenData<JwtClaims>, JwtAuthError> {
175        let header = decode_header(token)?;
176        let key = self.resolve_key(header.kid.as_deref())?;
177
178        let mut validation = Validation::new(header.alg);
179        validation.algorithms = self.config.algorithms.clone();
180        validation.set_issuer(&[self.config.issuer.as_str()]);
181        validation.set_audience(&[self.config.audience.as_str()]);
182
183        Ok(decode::<JwtClaims>(
184            token,
185            &DecodingKey::from_jwk(&key)?,
186            &validation,
187        )?)
188    }
189
190    /// Resolve the signing key for a token header.
191    ///
192    /// - With a `kid`: look it up in the cached JWKS; on a miss, re-fetch the
193    ///   JWKS once (the IdP may have rotated keys) before failing.
194    /// - Without a `kid`: only unambiguous key sets are accepted — if the JWKS
195    ///   holds more than one key, the token is rejected rather than verified
196    ///   against an arbitrary key.
197    fn resolve_key(&self, kid: Option<&str>) -> Result<Jwk, JwtAuthError> {
198        let jwks = self.jwks(false)?;
199        match kid {
200            Some(kid) => {
201                if let Some(key) = jwks.find(kid) {
202                    return Ok(key.clone());
203                }
204                // Unknown kid — refresh the JWKS once in case of key rotation.
205                let jwks = self.jwks(true)?;
206                jwks.find(kid).cloned().ok_or(JwtAuthError::MissingKey)
207            }
208            None => match jwks.keys.as_slice() {
209                [only] => Ok(only.clone()),
210                [] => Err(JwtAuthError::MissingKey),
211                _ => Err(JwtAuthError::MissingKid),
212            },
213        }
214    }
215
216    fn jwks(&self, force_refresh: bool) -> Result<JwkSet, JwtAuthError> {
217        if !force_refresh
218            && let Some(cached) = self.jwks.read().expect("jwks lock poisoned").as_ref()
219            && cached.fetched_at.elapsed() < self.config.jwks_ttl
220        {
221            return Ok(cached.keys.clone());
222        }
223
224        let value = self.http.get_json(&self.config.jwks_url, &[])?;
225        let keys: JwkSet = serde_json::from_value(value)?;
226        *self.jwks.write().expect("jwks lock poisoned") = Some(CachedJwks {
227            keys: keys.clone(),
228            fetched_at: Instant::now(),
229        });
230        Ok(keys)
231    }
232}
233
234impl Authenticator for JwtAuthenticator {
235    /// Verify the credential token as a JWT and return the *verified* subject.
236    ///
237    /// If the credentials claim a subject, it must match the token's `sub`
238    /// claim — a caller cannot authenticate as someone else's identity by
239    /// pairing a valid token with a different claimed subject.
240    fn verify_credentials(&self, credentials: &Credentials) -> Result<String, AgentError> {
241        let verified =
242            self.verify(credentials.token.expose())
243                .map_err(|e| AgentError::AuthFailed {
244                    reason: format!("jwt verification failed: {e}"),
245                })?;
246        if !credentials.subject.is_empty() && credentials.subject != verified.subject {
247            return Err(AgentError::AuthFailed {
248                reason: format!(
249                    "claimed subject '{}' does not match verified token subject '{}'",
250                    credentials.subject, verified.subject
251                ),
252            });
253        }
254        Ok(verified.subject)
255    }
256}
257
258/// Errors returned by [`JwtAuthenticator`].
259#[derive(Debug, thiserror::Error)]
260pub enum JwtAuthError {
261    /// Token validation failed.
262    #[error("jwt validation failed: {0}")]
263    Jwt(#[from] jsonwebtoken::errors::Error),
264    /// JWKS fetch failed.
265    #[error("jwks fetch failed: {0}")]
266    Http(#[from] Box<dyn std::error::Error + Send + Sync>),
267    /// JWKS JSON could not be parsed.
268    #[error("jwks parse failed: {0}")]
269    Json(#[from] serde_json::Error),
270    /// No matching signing key was found.
271    #[error("no matching signing key found in JWKS")]
272    MissingKey,
273    /// The token has no `kid` and the JWKS holds multiple keys.
274    #[error("token has no kid but JWKS is ambiguous (multiple keys)")]
275    MissingKid,
276    /// Token audience did not match the configured audience.
277    #[error("token audience did not match expected audience")]
278    InvalidAudience,
279}
280
281/// Policy engine backed by verified JWT permission claims.
282///
283/// This is intended as the fast first layer in a composed engine: allow obvious
284/// org-wide permissions from the token and delegate resource-specific decisions
285/// to RBAC, ODRL, WorkOS FGA, or another precise engine.
286pub struct JwtClaimsEngine {
287    subject: String,
288    permissions: HashSet<String>,
289    org_id: Option<String>,
290}
291
292impl JwtClaimsEngine {
293    /// Build an engine from a verified subject.
294    pub fn new(subject: VerifiedSubject) -> Self {
295        Self {
296            subject: subject.subject,
297            permissions: subject.permissions.into_iter().collect(),
298            org_id: subject.org_id,
299        }
300    }
301
302    /// Build an engine from raw permission strings.
303    pub fn from_permissions(
304        subject: impl Into<String>,
305        permissions: impl IntoIterator<Item = String>,
306    ) -> Self {
307        Self {
308            subject: subject.into(),
309            permissions: permissions.into_iter().collect(),
310            org_id: None,
311        }
312    }
313
314    fn permission_matches(&self, action: &str, resource: &str) -> bool {
315        if self.permissions.contains(action) {
316            return true;
317        }
318
319        let resource_type = resource.split(['/', ':']).next().unwrap_or(resource);
320        self.permissions
321            .contains(&format!("{resource_type}:{action}"))
322    }
323}
324
325impl PolicyEngine for JwtClaimsEngine {
326    fn check(&self, subject: &SubjectId, action: &str, resource: &ResourceId) -> PolicyResult {
327        let subject = subject.as_str();
328        let resource = resource.as_str();
329        debug!(subject, action, resource, org_id = ?self.org_id, "jwt claims check");
330
331        if subject != self.subject {
332            return PolicyResult::delegate(
333                "jwt",
334                format!("jwt claims are for '{}', not '{subject}'", self.subject),
335            );
336        }
337
338        if self.permission_matches(action, resource) {
339            PolicyResult::Allow
340        } else {
341            PolicyResult::delegate(
342                "jwt",
343                format!("permission '{action}' not present in jwt claims"),
344            )
345        }
346    }
347}
348
349#[allow(dead_code)]
350fn _assert_value_send_sync(_: Value) {}
351
#[cfg(test)]
mod tests {
    use super::*;
    use crate::http::StaticHttpClient;
    use chrono::{Duration, Utc};
    use jsonwebtoken::{EncodingKey, Header, encode};
    use serde_json::json;

    /// JWKS endpoint every test authenticator is pointed at.
    const JWKS_URL: &str = "https://issuer.example/.well-known/jwks.json";

    /// Run a policy check with plain string arguments.
    fn check(
        engine: &JwtClaimsEngine,
        subject: &str,
        action: &str,
        resource: &str,
    ) -> PolicyResult {
        let subject = SubjectId::from(subject);
        let resource = ResourceId::from(resource);
        engine.check(&subject, action, &resource)
    }

    /// HS256 test config plus a one-key JWKS (`kid` "test-key").
    ///
    /// The key material "c2VjcmV0" is base64 for "secret", the secret the
    /// test tokens are signed with.
    fn hs256_config_and_jwks(jwks_url: &str) -> (OidcConfig, serde_json::Value) {
        let jwks = json!({
            "keys": [{
                "kty": "oct",
                "kid": "test-key",
                "alg": "HS256",
                "k": "c2VjcmV0"
            }]
        });
        let mut config = OidcConfig::new("https://issuer.example", "typesec-test", jwks_url);
        config.algorithms = vec![Algorithm::HS256];
        (config, jwks)
    }

    /// Authenticator whose JWKS endpoint is served by a static HTTP client.
    fn static_authenticator(config: OidcConfig, jwks: serde_json::Value) -> JwtAuthenticator {
        let http = StaticHttpClient::new().with_response(JWKS_URL, jwks);
        JwtAuthenticator::with_http(config, Arc::new(http))
    }

    /// Claims that pass validation against the HS256 test config: matching
    /// issuer and audience, expiry ten minutes ahead, no optional claims.
    fn valid_claims() -> JwtClaims {
        JwtClaims {
            sub: "user_123".to_string(),
            iss: "https://issuer.example".to_string(),
            aud: Audience::Single("typesec-test".to_string()),
            exp: (Utc::now() + Duration::minutes(10)).timestamp() as usize,
            org_id: None,
            organization_membership_id: None,
            role: None,
            permissions: vec![],
        }
    }

    /// Sign `claims` with the shared HS256 secret, optionally naming a `kid`.
    fn sign_hs256(
        claims: &JwtClaims,
        kid: Option<&str>,
    ) -> Result<String, jsonwebtoken::errors::Error> {
        let mut header = Header::new(Algorithm::HS256);
        header.kid = kid.map(str::to_owned);
        encode(&header, claims, &EncodingKey::from_secret(b"secret"))
    }

    /// A valid token for `user_123` with the given `kid`.
    fn hs256_token(kid: Option<&str>) -> String {
        sign_hs256(&valid_claims(), kid).expect("token encodes")
    }

    #[test]
    fn jwt_claims_engine_allows_direct_permission() {
        let engine = JwtClaimsEngine::from_permissions("user_1", ["read".to_string()]);
        let decision = check(&engine, "user_1", "read", "project/123");
        assert_eq!(decision, PolicyResult::Allow);
    }

    #[test]
    fn jwt_claims_engine_allows_resource_type_permission() {
        let engine = JwtClaimsEngine::from_permissions("user_1", ["project:edit".to_string()]);
        let decision = check(&engine, "user_1", "edit", "project/123");
        assert_eq!(decision, PolicyResult::Allow);
    }

    #[test]
    fn jwt_claims_engine_delegates_missing_permission() {
        let engine = JwtClaimsEngine::from_permissions("user_1", ["read".to_string()]);
        let decision = check(&engine, "user_1", "write", "project/123");
        assert!(matches!(decision, PolicyResult::Delegate(_)));
    }

    #[test]
    fn jwt_authenticator_verifies_hs256_token_from_jwks() {
        let (config, jwks) = hs256_config_and_jwks(JWKS_URL);
        let auth = static_authenticator(config, jwks);

        let claims = JwtClaims {
            org_id: Some("org_123".to_string()),
            organization_membership_id: Some("om_123".to_string()),
            role: Some("org_member".to_string()),
            permissions: vec!["org:view".to_string(), "project:read".to_string()],
            ..valid_claims()
        };
        let token = sign_hs256(&claims, Some("test-key")).expect("token should encode");

        let verified = auth.verify(&token).expect("token should verify");
        assert_eq!(verified.subject, "user_123");
        assert_eq!(verified.workos_membership_subject(), "om_123");
        assert_eq!(verified.permissions, vec!["org:view", "project:read"]);
    }

    #[test]
    fn jwt_authenticator_rejects_wrong_audience() {
        let (config, jwks) = hs256_config_and_jwks(JWKS_URL);
        let auth = static_authenticator(config, jwks);

        let claims = JwtClaims {
            aud: Audience::Single("other-audience".to_string()),
            ..valid_claims()
        };
        let token = sign_hs256(&claims, Some("test-key")).expect("token should encode");

        assert!(auth.verify(&token).is_err());
    }

    #[test]
    fn unknown_kid_triggers_one_jwks_refetch() {
        use crate::http::RecordingHttpClient;
        let (config, jwks) = hs256_config_and_jwks(JWKS_URL);
        let http = RecordingHttpClient::new().with_response(JWKS_URL, jwks);
        let auth = JwtAuthenticator::with_http(config, Arc::new(http.clone()));

        let token = hs256_token(Some("rotated-away-key"));
        let result = auth.verify(&token);

        assert!(matches!(result, Err(JwtAuthError::MissingKey)));
        // Initial fetch + one rotation-driven refetch, no more.
        assert_eq!(http.requests().len(), 2);
    }

    #[test]
    fn missing_kid_with_multiple_keys_is_rejected() {
        let (config, _) = hs256_config_and_jwks(JWKS_URL);
        // Two keys, so a token without a `kid` cannot be matched to one.
        let two_keys = json!({
            "keys": [
                { "kty": "oct", "kid": "a", "alg": "HS256", "k": "c2VjcmV0" },
                { "kty": "oct", "kid": "b", "alg": "HS256", "k": "b3RoZXI" }
            ]
        });
        let auth = static_authenticator(config, two_keys);

        let token = hs256_token(None);
        assert!(matches!(auth.verify(&token), Err(JwtAuthError::MissingKid)));
    }

    #[test]
    fn authenticator_rejects_mismatched_claimed_subject() {
        let (config, jwks) = hs256_config_and_jwks(JWKS_URL);
        let auth = static_authenticator(config, jwks);
        let token = hs256_token(Some("test-key"));

        // Claiming someone else's identity with a valid token must fail.
        let mismatched = Credentials::new("user_999", token.clone());
        assert!(auth.verify_credentials(&mismatched).is_err());

        // The verified subject wins; an empty claimed subject is allowed.
        let unclaimed = Credentials::new("", token);
        assert_eq!(
            auth.verify_credentials(&unclaimed).expect("verifies"),
            "user_123"
        );
    }
}