1use std::collections::HashSet;
4use std::sync::{Arc, RwLock};
5
6use jsonwebtoken::{
7 Algorithm, DecodingKey, TokenData, Validation, decode, decode_header, jwk::JwkSet,
8};
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11use tracing::debug;
12use typesec_core::policy::{PolicyEngine, PolicyResult};
13
14use crate::http::{HttpClient, ReqwestHttpClient};
15
/// Settings identifying which OIDC issuer's tokens a [`JwtAuthenticator`] accepts.
#[derive(Debug, Clone)]
pub struct OidcConfig {
    /// Expected `iss` claim; it is passed to the JWT validation as the only accepted issuer.
    pub issuer: String,
    /// Expected audience; the token's `aud` claim must contain this value.
    pub audience: String,
    /// URL from which the JSON Web Key Set (signing keys) is fetched.
    pub jwks_url: String,
    /// Signature algorithms a token may be verified with. [`OidcConfig::new`] defaults this to
    /// `RS256` only.
    pub algorithms: Vec<Algorithm>,
}
28
29impl OidcConfig {
30 pub fn new(
32 issuer: impl Into<String>,
33 audience: impl Into<String>,
34 jwks_url: impl Into<String>,
35 ) -> Self {
36 Self {
37 issuer: issuer.into(),
38 audience: audience.into(),
39 jwks_url: jwks_url.into(),
40 algorithms: vec![Algorithm::RS256],
41 }
42 }
43}
44
/// Claims decoded from a verified JWT payload.
///
/// Only `sub`, `iss`, `aud` and `exp` are mandatory. The remaining fields fall back to their
/// `Default` (`None` / empty) when absent from the token.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JwtClaims {
    /// Subject identifier (`sub`).
    pub sub: String,
    /// Issuer (`iss`).
    pub iss: String,
    /// Audience (`aud`), which may be a single string or a list.
    pub aud: Audience,
    /// Expiration time (`exp`) as a Unix timestamp in seconds.
    pub exp: usize,
    /// Organization the token was issued for, if any.
    #[serde(default)]
    pub org_id: Option<String>,
    /// Organization-membership identifier, if present in the token.
    #[serde(default)]
    pub organization_membership_id: Option<String>,
    /// Single role claim; converted into a one-element `roles` list by `VerifiedSubject`.
    #[serde(default)]
    pub role: Option<String>,
    /// Permission strings granted to the subject (e.g. `read` or `project:read`).
    #[serde(default)]
    pub permissions: Vec<String>,
}
69
/// The JWT `aud` claim, which may be encoded either as one string or as an array of strings.
///
/// `untagged` makes serde try each variant in declaration order.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Audience {
    /// `"aud": "value"`
    Single(String),
    /// `"aud": ["a", "b"]`
    Multiple(Vec<String>),
}
79
80impl Audience {
81 fn contains(&self, needle: &str) -> bool {
82 match self {
83 Self::Single(value) => value == needle,
84 Self::Multiple(values) => values.iter().any(|value| value == needle),
85 }
86 }
87}
88
/// Identity and authorization data extracted from a successfully verified JWT.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct VerifiedSubject {
    /// The token's `sub` claim.
    pub subject: String,
    /// Organization the token was issued for, if any.
    pub org_id: Option<String>,
    /// Organization-membership identifier, if the token carried one.
    pub organization_membership_id: Option<String>,
    /// Roles granted to the subject (at most one when built from `JwtClaims`).
    pub roles: Vec<String>,
    /// Permission strings granted to the subject.
    pub permissions: Vec<String>,
}
103
104impl VerifiedSubject {
105 pub fn workos_membership_subject(&self) -> &str {
107 self.organization_membership_id
108 .as_deref()
109 .unwrap_or(&self.subject)
110 }
111}
112
113impl From<JwtClaims> for VerifiedSubject {
114 fn from(claims: JwtClaims) -> Self {
115 Self {
116 subject: claims.sub,
117 org_id: claims.org_id,
118 organization_membership_id: claims.organization_membership_id,
119 roles: claims.role.into_iter().collect(),
120 permissions: claims.permissions,
121 }
122 }
123}
124
125pub struct JwtAuthenticator {
127 config: OidcConfig,
128 http: Arc<dyn HttpClient>,
129 jwks: RwLock<Option<JwkSet>>,
130}
131
132impl JwtAuthenticator {
133 pub fn new(config: OidcConfig) -> Self {
135 Self::with_http(config, Arc::new(ReqwestHttpClient::new()))
136 }
137
138 pub fn with_http(config: OidcConfig, http: Arc<dyn HttpClient>) -> Self {
140 Self {
141 config,
142 http,
143 jwks: RwLock::new(None),
144 }
145 }
146
147 pub fn verify(&self, token: &str) -> Result<VerifiedSubject, JwtAuthError> {
149 let data = self.decode_claims(token)?;
150 if !data.claims.aud.contains(&self.config.audience) {
151 return Err(JwtAuthError::InvalidAudience);
152 }
153 Ok(data.claims.into())
154 }
155
156 fn decode_claims(&self, token: &str) -> Result<TokenData<JwtClaims>, JwtAuthError> {
157 let header = decode_header(token)?;
158 let jwks = self.jwks()?;
159 let key = match header.kid.as_deref() {
160 Some(kid) => jwks.find(kid).ok_or(JwtAuthError::MissingKey)?,
161 None => jwks.keys.first().ok_or(JwtAuthError::MissingKey)?,
162 };
163
164 let mut validation = Validation::new(header.alg);
165 validation.algorithms = self.config.algorithms.clone();
166 validation.set_issuer(&[self.config.issuer.as_str()]);
167 validation.set_audience(&[self.config.audience.as_str()]);
168
169 Ok(decode::<JwtClaims>(
170 token,
171 &DecodingKey::from_jwk(key)?,
172 &validation,
173 )?)
174 }
175
176 fn jwks(&self) -> Result<JwkSet, JwtAuthError> {
177 if let Some(jwks) = self.jwks.read().expect("jwks lock poisoned").clone() {
178 return Ok(jwks);
179 }
180
181 let value = self.http.get_json(&self.config.jwks_url, &[])?;
182 let jwks: JwkSet = serde_json::from_value(value)?;
183 *self.jwks.write().expect("jwks lock poisoned") = Some(jwks.clone());
184 Ok(jwks)
185 }
186}
187
/// Errors returned by [`JwtAuthenticator::verify`].
#[derive(Debug, thiserror::Error)]
pub enum JwtAuthError {
    /// Header decoding, key construction, or claim/signature validation failed in `jsonwebtoken`.
    #[error("jwt validation failed: {0}")]
    Jwt(#[from] jsonwebtoken::errors::Error),
    /// The HTTP client failed to fetch the JWKS.
    #[error("jwks fetch failed: {0}")]
    Http(#[from] Box<dyn std::error::Error + Send + Sync>),
    /// The fetched JWKS document could not be deserialized.
    #[error("jwks parse failed: {0}")]
    Json(#[from] serde_json::Error),
    /// No JWKS key matched the token's `kid`, or the JWKS contained no keys at all.
    #[error("no matching signing key found in JWKS")]
    MissingKey,
    /// The token's `aud` claim does not contain the configured audience.
    #[error("token audience did not match expected audience")]
    InvalidAudience,
}
207
/// A `PolicyEngine` that answers authorization checks from the permissions carried in a
/// verified JWT.
///
/// It only ever allows; anything it cannot grant is delegated to the next engine.
pub struct JwtClaimsEngine {
    /// Subject the claims were issued for; checks for any other subject are delegated.
    subject: String,
    /// Granted permissions, either bare actions (`read`) or `resource_type:action` pairs.
    permissions: HashSet<String>,
    /// Organization from the token; in this file it is only included in debug logging.
    org_id: Option<String>,
}
218
219impl JwtClaimsEngine {
220 pub fn new(subject: VerifiedSubject) -> Self {
222 Self {
223 subject: subject.subject,
224 permissions: subject.permissions.into_iter().collect(),
225 org_id: subject.org_id,
226 }
227 }
228
229 pub fn from_permissions(
231 subject: impl Into<String>,
232 permissions: impl IntoIterator<Item = String>,
233 ) -> Self {
234 Self {
235 subject: subject.into(),
236 permissions: permissions.into_iter().collect(),
237 org_id: None,
238 }
239 }
240
241 fn permission_matches(&self, action: &str, resource: &str) -> bool {
242 if self.permissions.contains(action) {
243 return true;
244 }
245
246 let resource_type = resource.split(['/', ':']).next().unwrap_or(resource);
247 self.permissions
248 .contains(&format!("{resource_type}:{action}"))
249 }
250}
251
252impl PolicyEngine for JwtClaimsEngine {
253 fn check(&self, subject: &str, action: &str, resource: &str) -> PolicyResult {
254 debug!(subject, action, resource, org_id = ?self.org_id, "jwt claims check");
255
256 if subject != self.subject {
257 return PolicyResult::Delegate(format!(
258 "jwt claims are for '{}', not '{subject}'",
259 self.subject
260 ));
261 }
262
263 if self.permission_matches(action, resource) {
264 PolicyResult::Allow
265 } else {
266 PolicyResult::Delegate(format!("permission '{action}' not present in jwt claims"))
267 }
268 }
269}
270
271#[allow(dead_code)]
272fn _assert_value_send_sync(_: Value) {}
273
#[cfg(test)]
mod tests {
    use super::*;
    use crate::http::StaticHttpClient;
    use chrono::{Duration, Utc};
    use jsonwebtoken::{EncodingKey, Header, encode};
    use serde_json::json;

    const ISSUER: &str = "https://issuer.example";
    const AUDIENCE: &str = "typesec-test";
    const JWKS_URL: &str = "https://issuer.example/.well-known/jwks.json";
    const KEY_ID: &str = "test-key";
    /// Raw HMAC secret; `c2VjcmV0` in the test JWKS is its base64url encoding.
    const SECRET: &[u8] = b"secret";

    /// Authenticator whose JWKS endpoint serves a single HS256 key (`KEY_ID` / `SECRET`).
    fn hs256_authenticator() -> JwtAuthenticator {
        let http = StaticHttpClient::new().with_response(
            JWKS_URL,
            json!({
                "keys": [{
                    "kty": "oct",
                    "kid": KEY_ID,
                    "alg": "HS256",
                    "k": "c2VjcmV0"
                }]
            }),
        );
        let mut config = OidcConfig::new(ISSUER, AUDIENCE, JWKS_URL);
        config.algorithms = vec![Algorithm::HS256];
        JwtAuthenticator::with_http(config, Arc::new(http))
    }

    /// Minimal claims for `user_123`, valid for ten more minutes, addressed to `audience`.
    fn claims_for(audience: &str) -> JwtClaims {
        JwtClaims {
            sub: "user_123".to_string(),
            iss: ISSUER.to_string(),
            aud: Audience::Single(audience.to_string()),
            exp: (Utc::now() + Duration::minutes(10)).timestamp() as usize,
            org_id: None,
            organization_membership_id: None,
            role: None,
            permissions: vec![],
        }
    }

    /// Signs `claims` with the test HS256 key, naming it in the header `kid`.
    fn sign_hs256(claims: &JwtClaims) -> String {
        let mut header = Header::new(Algorithm::HS256);
        header.kid = Some(KEY_ID.to_string());
        encode(&header, claims, &EncodingKey::from_secret(SECRET)).expect("token should encode")
    }

    #[test]
    fn jwt_claims_engine_allows_direct_permission() {
        let engine = JwtClaimsEngine::from_permissions("user_1", ["read".to_string()]);
        assert_eq!(
            engine.check("user_1", "read", "project/123"),
            PolicyResult::Allow
        );
    }

    #[test]
    fn jwt_claims_engine_allows_resource_type_permission() {
        let engine = JwtClaimsEngine::from_permissions("user_1", ["project:edit".to_string()]);
        assert_eq!(
            engine.check("user_1", "edit", "project/123"),
            PolicyResult::Allow
        );
        assert_eq!(
            engine.check("user_1", "edit", "project:123"),
            PolicyResult::Allow
        );
    }

    #[test]
    fn jwt_claims_engine_delegates_missing_permission() {
        let engine = JwtClaimsEngine::from_permissions("user_1", ["read".to_string()]);
        assert!(matches!(
            engine.check("user_1", "write", "project/123"),
            PolicyResult::Delegate(_)
        ));
    }

    #[test]
    fn jwt_claims_engine_delegates_other_subject() {
        let engine = JwtClaimsEngine::from_permissions("user_1", ["read".to_string()]);
        assert!(matches!(
            engine.check("user_2", "read", "project/123"),
            PolicyResult::Delegate(_)
        ));
    }

    #[test]
    fn jwt_claims_engine_from_verified_subject_uses_its_permissions() {
        let subject = VerifiedSubject {
            subject: "user_1".to_string(),
            org_id: Some("org_1".to_string()),
            organization_membership_id: None,
            roles: vec![],
            permissions: vec!["project:read".to_string()],
        };
        let engine = JwtClaimsEngine::new(subject);
        assert_eq!(
            engine.check("user_1", "read", "project/42"),
            PolicyResult::Allow
        );
        assert!(matches!(
            engine.check("user_1", "write", "project/42"),
            PolicyResult::Delegate(_)
        ));
    }

    #[test]
    fn jwt_authenticator_verifies_hs256_token_from_jwks() {
        let auth = hs256_authenticator();
        let claims = JwtClaims {
            org_id: Some("org_123".to_string()),
            organization_membership_id: Some("om_123".to_string()),
            role: Some("org_member".to_string()),
            permissions: vec!["org:view".to_string(), "project:read".to_string()],
            ..claims_for(AUDIENCE)
        };

        let verified = auth
            .verify(&sign_hs256(&claims))
            .expect("token should verify");
        assert_eq!(verified.subject, "user_123");
        assert_eq!(verified.org_id.as_deref(), Some("org_123"));
        assert_eq!(verified.workos_membership_subject(), "om_123");
        assert_eq!(verified.roles, vec!["org_member"]);
        assert_eq!(verified.permissions, vec!["org:view", "project:read"]);
    }

    #[test]
    fn jwt_authenticator_rejects_wrong_audience() {
        let auth = hs256_authenticator();
        let token = sign_hs256(&claims_for("other-audience"));
        assert!(auth.verify(&token).is_err());
    }

    #[test]
    fn jwt_authenticator_rejects_expired_token() {
        let auth = hs256_authenticator();
        let claims = JwtClaims {
            exp: (Utc::now() - Duration::minutes(10)).timestamp() as usize,
            ..claims_for(AUDIENCE)
        };
        assert!(auth.verify(&sign_hs256(&claims)).is_err());
    }
}