Skip to main content

typesec_integrations/
jwt.rs

1//! JWT/OIDC authentication helpers and a fast claims-backed policy engine.
2
3use std::collections::HashSet;
4use std::sync::{Arc, RwLock};
5
6use jsonwebtoken::{
7    Algorithm, DecodingKey, TokenData, Validation, decode, decode_header, jwk::JwkSet,
8};
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11use tracing::debug;
12use typesec_core::policy::{PolicyEngine, PolicyResult};
13
14use crate::http::{HttpClient, ReqwestHttpClient};
15
/// OIDC validation settings.
///
/// Consumed by [`JwtAuthenticator`], which checks every token's issuer,
/// audience, and signing algorithm against these values.
#[derive(Debug, Clone)]
pub struct OidcConfig {
    /// Expected issuer claim; handed to jsonwebtoken as the only accepted `iss`.
    pub issuer: String,
    /// Expected audience claim; the token's `aud` must contain this value.
    pub audience: String,
    /// JWKS endpoint used to resolve signing keys. Fetched through the
    /// authenticator's HTTP client on first use and then cached.
    pub jwks_url: String,
    /// Accepted signing algorithms. This list replaces jsonwebtoken's
    /// `Validation::algorithms`, so only these algorithms are accepted.
    pub algorithms: Vec<Algorithm>,
}
28
29impl OidcConfig {
30    /// Create a config using RS256, the common AuthKit/OIDC default.
31    pub fn new(
32        issuer: impl Into<String>,
33        audience: impl Into<String>,
34        jwks_url: impl Into<String>,
35    ) -> Self {
36        Self {
37            issuer: issuer.into(),
38            audience: audience.into(),
39            jwks_url: jwks_url.into(),
40            algorithms: vec![Algorithm::RS256],
41        }
42    }
43}
44
/// Claims Typesec cares about from an access token.
///
/// Unknown claims in the token are ignored during deserialization. The
/// `#[serde(default)]` fields may be absent from the token.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JwtClaims {
    /// Subject identifier.
    pub sub: String,
    /// Issuer.
    pub iss: String,
    /// Audience. Some providers encode this as a string, others as a list.
    pub aud: Audience,
    /// Expiration timestamp (seconds since the Unix epoch, as produced by the
    /// tests via `timestamp()`).
    pub exp: usize,
    /// Optional organization identifier.
    #[serde(default)]
    pub org_id: Option<String>,
    /// Optional organization membership identifier.
    #[serde(default)]
    pub organization_membership_id: Option<String>,
    /// Optional single role; becomes `VerifiedSubject::roles` on conversion.
    #[serde(default)]
    pub role: Option<String>,
    /// Optional permission list; empty when the claim is missing.
    #[serde(default)]
    pub permissions: Vec<String>,
}
69
/// JWT audience represented as either a string or list.
///
/// `#[serde(untagged)]` lets serde accept either JSON shape for `aud`, trying
/// `Single` before `Multiple`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Audience {
    /// Single audience (`"aud": "x"`).
    Single(String),
    /// Multiple audiences (`"aud": ["x", "y"]`).
    Multiple(Vec<String>),
}
79
80impl Audience {
81    fn contains(&self, needle: &str) -> bool {
82        match self {
83            Self::Single(value) => value == needle,
84            Self::Multiple(values) => values.iter().any(|value| value == needle),
85        }
86    }
87}
88
/// Verified identity extracted from an OIDC/JWT access token.
///
/// Built from [`JwtClaims`] once the token's signature and claims have been
/// validated.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct VerifiedSubject {
    /// Subject identifier (the token's `sub`).
    pub subject: String,
    /// Optional organization identifier.
    pub org_id: Option<String>,
    /// Optional organization membership identifier.
    pub organization_membership_id: Option<String>,
    /// Role names carried by the token. When converted from `JwtClaims` this
    /// holds at most one entry, taken from the single `role` claim.
    pub roles: Vec<String>,
    /// Permission names carried by the token.
    pub permissions: Vec<String>,
}
103
104impl VerifiedSubject {
105    /// Return the best subject identifier for WorkOS FGA checks.
106    pub fn workos_membership_subject(&self) -> &str {
107        self.organization_membership_id
108            .as_deref()
109            .unwrap_or(&self.subject)
110    }
111}
112
113impl From<JwtClaims> for VerifiedSubject {
114    fn from(claims: JwtClaims) -> Self {
115        Self {
116            subject: claims.sub,
117            org_id: claims.org_id,
118            organization_membership_id: claims.organization_membership_id,
119            roles: claims.role.into_iter().collect(),
120            permissions: claims.permissions,
121        }
122    }
123}
124
125/// JWT authenticator that verifies tokens against a JWKS endpoint.
126pub struct JwtAuthenticator {
127    config: OidcConfig,
128    http: Arc<dyn HttpClient>,
129    jwks: RwLock<Option<JwkSet>>,
130}
131
132impl JwtAuthenticator {
133    /// Create an authenticator using the default reqwest HTTP client.
134    pub fn new(config: OidcConfig) -> Self {
135        Self::with_http(config, Arc::new(ReqwestHttpClient::new()))
136    }
137
138    /// Create an authenticator with an injected HTTP client.
139    pub fn with_http(config: OidcConfig, http: Arc<dyn HttpClient>) -> Self {
140        Self {
141            config,
142            http,
143            jwks: RwLock::new(None),
144        }
145    }
146
147    /// Verify a bearer token and return its Typesec subject model.
148    pub fn verify(&self, token: &str) -> Result<VerifiedSubject, JwtAuthError> {
149        let data = self.decode_claims(token)?;
150        if !data.claims.aud.contains(&self.config.audience) {
151            return Err(JwtAuthError::InvalidAudience);
152        }
153        Ok(data.claims.into())
154    }
155
156    fn decode_claims(&self, token: &str) -> Result<TokenData<JwtClaims>, JwtAuthError> {
157        let header = decode_header(token)?;
158        let jwks = self.jwks()?;
159        let key = match header.kid.as_deref() {
160            Some(kid) => jwks.find(kid).ok_or(JwtAuthError::MissingKey)?,
161            None => jwks.keys.first().ok_or(JwtAuthError::MissingKey)?,
162        };
163
164        let mut validation = Validation::new(header.alg);
165        validation.algorithms = self.config.algorithms.clone();
166        validation.set_issuer(&[self.config.issuer.as_str()]);
167        validation.set_audience(&[self.config.audience.as_str()]);
168
169        Ok(decode::<JwtClaims>(
170            token,
171            &DecodingKey::from_jwk(key)?,
172            &validation,
173        )?)
174    }
175
176    fn jwks(&self) -> Result<JwkSet, JwtAuthError> {
177        if let Some(jwks) = self.jwks.read().expect("jwks lock poisoned").clone() {
178            return Ok(jwks);
179        }
180
181        let value = self.http.get_json(&self.config.jwks_url, &[])?;
182        let jwks: JwkSet = serde_json::from_value(value)?;
183        *self.jwks.write().expect("jwks lock poisoned") = Some(jwks.clone());
184        Ok(jwks)
185    }
186}
187
/// Errors returned by [`JwtAuthenticator`].
#[derive(Debug, thiserror::Error)]
pub enum JwtAuthError {
    /// Token validation failed: malformed header/token, bad signature,
    /// disallowed algorithm, or a failed claim check performed by jsonwebtoken.
    /// Also produced when a JWK cannot be turned into a decoding key.
    #[error("jwt validation failed: {0}")]
    Jwt(#[from] jsonwebtoken::errors::Error),
    /// JWKS fetch failed. Note that `#[from]` converts any boxed
    /// `Send + Sync` error into this variant.
    #[error("jwks fetch failed: {0}")]
    Http(#[from] Box<dyn std::error::Error + Send + Sync>),
    /// JWKS JSON could not be parsed.
    #[error("jwks parse failed: {0}")]
    Json(#[from] serde_json::Error),
    /// No matching signing key was found (unknown `kid`, or an empty JWKS for
    /// a token without a `kid`).
    #[error("no matching signing key found in JWKS")]
    MissingKey,
    /// Token audience did not match the configured audience. jsonwebtoken
    /// also validates the audience, so a mismatch may surface as `Jwt` instead.
    #[error("token audience did not match expected audience")]
    InvalidAudience,
}
207
/// Policy engine backed by verified JWT permission claims.
///
/// This is intended as the fast first layer in a composed engine: allow obvious
/// org-wide permissions from the token and delegate resource-specific decisions
/// to RBAC, ODRL, WorkOS FGA, or another precise engine.
#[derive(Debug, Clone)]
pub struct JwtClaimsEngine {
    /// Subject the claims were issued to; checks for any other subject delegate.
    subject: String,
    /// Permission strings from the token: either bare actions (`read`) or
    /// `<resource type>:<action>` pairs (`project:edit`).
    permissions: HashSet<String>,
    /// Organization from the token; only recorded in the debug log of each check.
    org_id: Option<String>,
}
218
219impl JwtClaimsEngine {
220    /// Build an engine from a verified subject.
221    pub fn new(subject: VerifiedSubject) -> Self {
222        Self {
223            subject: subject.subject,
224            permissions: subject.permissions.into_iter().collect(),
225            org_id: subject.org_id,
226        }
227    }
228
229    /// Build an engine from raw permission strings.
230    pub fn from_permissions(
231        subject: impl Into<String>,
232        permissions: impl IntoIterator<Item = String>,
233    ) -> Self {
234        Self {
235            subject: subject.into(),
236            permissions: permissions.into_iter().collect(),
237            org_id: None,
238        }
239    }
240
241    fn permission_matches(&self, action: &str, resource: &str) -> bool {
242        if self.permissions.contains(action) {
243            return true;
244        }
245
246        let resource_type = resource.split(['/', ':']).next().unwrap_or(resource);
247        self.permissions
248            .contains(&format!("{resource_type}:{action}"))
249    }
250}
251
252impl PolicyEngine for JwtClaimsEngine {
253    fn check(&self, subject: &str, action: &str, resource: &str) -> PolicyResult {
254        debug!(subject, action, resource, org_id = ?self.org_id, "jwt claims check");
255
256        if subject != self.subject {
257            return PolicyResult::Delegate(format!(
258                "jwt claims are for '{}', not '{subject}'",
259                self.subject
260            ));
261        }
262
263        if self.permission_matches(action, resource) {
264            PolicyResult::Allow
265        } else {
266            PolicyResult::Delegate(format!("permission '{action}' not present in jwt claims"))
267        }
268    }
269}
270
271#[allow(dead_code)]
272fn _assert_value_send_sync(_: Value) {}
273
#[cfg(test)]
mod tests {
    use super::*;
    use crate::http::StaticHttpClient;
    use chrono::{Duration, Utc};
    use jsonwebtoken::{EncodingKey, Header, encode};
    use serde_json::json;

    const JWKS_URL: &str = "https://issuer.example/.well-known/jwks.json";
    const ISSUER: &str = "https://issuer.example";
    const AUDIENCE: &str = "typesec-test";
    /// Key id of the only key served by `hs256_authenticator`'s JWKS.
    const KID: &str = "test-key";
    /// HMAC secret matching the JWKS key's `k` (`c2VjcmV0` is base64url for `secret`).
    const SECRET: &[u8] = b"secret";

    /// Authenticator whose JWKS endpoint serves a single HS256 key (`KID`).
    fn hs256_authenticator() -> JwtAuthenticator {
        let http = StaticHttpClient::new().with_response(
            JWKS_URL,
            json!({
                "keys": [{
                    "kty": "oct",
                    "kid": KID,
                    "alg": "HS256",
                    "k": "c2VjcmV0"
                }]
            }),
        );
        let mut config = OidcConfig::new(ISSUER, AUDIENCE, JWKS_URL);
        config.algorithms = vec![Algorithm::HS256];
        JwtAuthenticator::with_http(config, Arc::new(http))
    }

    /// Minimal claims for `user_123` with the given audience, valid for ten minutes.
    fn claims_for(aud: Audience) -> JwtClaims {
        JwtClaims {
            sub: "user_123".to_string(),
            iss: ISSUER.to_string(),
            aud,
            exp: (Utc::now() + Duration::minutes(10)).timestamp() as usize,
            org_id: None,
            organization_membership_id: None,
            role: None,
            permissions: vec![],
        }
    }

    /// Sign `claims` with HS256 under the given key id.
    fn sign(claims: &JwtClaims, kid: &str) -> String {
        let mut header = Header::new(Algorithm::HS256);
        header.kid = Some(kid.to_string());
        encode(&header, claims, &EncodingKey::from_secret(SECRET)).expect("token should encode")
    }

    #[test]
    fn jwt_claims_engine_allows_direct_permission() {
        let engine = JwtClaimsEngine::from_permissions("user_1", ["read".to_string()]);
        assert_eq!(
            engine.check("user_1", "read", "project/123"),
            PolicyResult::Allow
        );
    }

    #[test]
    fn jwt_claims_engine_allows_resource_type_permission() {
        let engine = JwtClaimsEngine::from_permissions("user_1", ["project:edit".to_string()]);
        assert_eq!(
            engine.check("user_1", "edit", "project/123"),
            PolicyResult::Allow
        );
    }

    #[test]
    fn jwt_claims_engine_delegates_missing_permission() {
        let engine = JwtClaimsEngine::from_permissions("user_1", ["read".to_string()]);
        assert!(matches!(
            engine.check("user_1", "write", "project/123"),
            PolicyResult::Delegate(_)
        ));
    }

    #[test]
    fn jwt_claims_engine_delegates_other_subject() {
        let engine = JwtClaimsEngine::from_permissions("user_1", ["read".to_string()]);
        assert!(matches!(
            engine.check("user_2", "read", "project/123"),
            PolicyResult::Delegate(_)
        ));
    }

    #[test]
    fn jwt_authenticator_verifies_hs256_token_from_jwks() {
        let auth = hs256_authenticator();
        let claims = JwtClaims {
            org_id: Some("org_123".to_string()),
            organization_membership_id: Some("om_123".to_string()),
            role: Some("org_member".to_string()),
            permissions: vec!["org:view".to_string(), "project:read".to_string()],
            ..claims_for(Audience::Single(AUDIENCE.to_string()))
        };

        let verified = auth.verify(&sign(&claims, KID)).expect("token should verify");
        assert_eq!(verified.subject, "user_123");
        assert_eq!(verified.workos_membership_subject(), "om_123");
        assert_eq!(verified.roles, vec!["org_member"]);
        assert_eq!(verified.permissions, vec!["org:view", "project:read"]);
    }

    #[test]
    fn jwt_authenticator_accepts_audience_list() {
        let auth = hs256_authenticator();
        let claims = claims_for(Audience::Multiple(vec![
            "other-audience".to_string(),
            AUDIENCE.to_string(),
        ]));

        let verified = auth.verify(&sign(&claims, KID)).expect("token should verify");
        assert_eq!(verified.workos_membership_subject(), "user_123");
    }

    #[test]
    fn jwt_authenticator_rejects_wrong_audience() {
        let auth = hs256_authenticator();
        let claims = claims_for(Audience::Single("other-audience".to_string()));

        assert!(auth.verify(&sign(&claims, KID)).is_err());
    }

    #[test]
    fn jwt_authenticator_rejects_unknown_kid() {
        let auth = hs256_authenticator();
        let claims = claims_for(Audience::Single(AUDIENCE.to_string()));

        assert!(matches!(
            auth.verify(&sign(&claims, "rotated-key")),
            Err(JwtAuthError::MissingKey)
        ));
    }
}