1#[cfg(feature = "verifier")]
4use crate::{
5 cache::Cache,
6 discovery::discover,
7 errors::{Error, Result},
8 http::HttpClient,
9 id_token::fetch_jwks,
10 jwks::Jwks,
11 types::VerifiedClaims,
12};
13
14#[cfg(feature = "verifier")]
15use base64::{engine::general_purpose, Engine as _};
16#[cfg(feature = "verifier")]
17use josekit::{
18 jws::RS256,
19 jwt::{self, JwtPayload},
20};
21#[cfg(feature = "verifier")]
22use std::collections::HashMap;
23#[cfg(feature = "verifier")]
24use std::time::{SystemTime, UNIX_EPOCH};
25
#[cfg(feature = "verifier")]
/// Verifies bearer access tokens against the signing keys published by an allowed issuer.
///
/// Construct one with [`JwtVerifier::new`] or [`JwtVerifier::builder`].
pub struct JwtVerifier<C: Cache<String, Jwks>, H: HttpClient> {
    /// Issuer lookup table. Values are `iss` values accepted as-is from tokens; keys are
    /// looked up by tenant in `resolve_issuer_with_tenant` (and by `iss` in `resolve_issuer`).
    pub issuer_map: HashMap<String, String>,
    /// Audience that must be listed in a token's `aud` claim.
    pub audience: String,
    /// HTTP client used for issuer discovery and JWKS retrieval.
    pub http: std::sync::Arc<H>,
    /// Cache handed to JWKS retrieval.
    pub cache: std::sync::Arc<C>,
    /// Leeway, in seconds, applied to the time-based claim checks.
    pub clock_skew_sec: i64,
    /// Issuer to fall back to when a token's issuer cannot be resolved from `issuer_map`.
    pub default_issuer: Option<String>,
}
42
#[cfg(feature = "verifier")]
impl<C: Cache<String, Jwks>, H: HttpClient> JwtVerifier<C, H> {
    /// Creates a verifier with a 60-second clock skew and no default issuer.
    pub fn new(
        issuer_map: HashMap<String, String>,
        audience: String,
        http: std::sync::Arc<H>,
        cache: std::sync::Arc<C>,
    ) -> Self {
        Self { issuer_map, audience, http, cache, clock_skew_sec: 60, default_issuer: None }
    }

    /// Returns a [`JwtVerifierBuilder`] with nothing configured yet.
    pub fn builder() -> JwtVerifierBuilder<C, H> {
        JwtVerifierBuilder::default()
    }

    /// Verifies a bearer access token and returns its validated claims.
    ///
    /// `bearer` may be a raw compact JWT or an `Authorization` header value of the form
    /// `Bearer <token>`. The scheme is matched case-insensitively, and surrounding
    /// whitespace is ignored.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - the token is malformed;
    /// - its issuer cannot be resolved;
    /// - it carries no `kid`;
    /// - discovery or JWKS retrieval fails;
    /// - no published key matches the `kid`;
    /// - the signature does not verify;
    /// - claim validation fails.
    pub async fn verify(&self, bearer: &str) -> Result<VerifiedClaims> {
        let token = Self::strip_bearer(bearer);

        // The issuer is read unverified only to decide whose keys to fetch; the `iss`
        // claim is compared against the resolved issuer again after signature checking.
        let unverified_issuer = extract_unverified_issuer(token)?;
        let expected_issuer = self.resolve_issuer(&unverified_issuer)?;

        // Parse the header before any network I/O so tokens without a `kid` are
        // rejected without triggering discovery or a JWKS fetch.
        let kid = extract_kid(token)?.ok_or_else(|| Error::Jwt("Token missing kid".into()))?;

        // NOTE(review): `NoOpCache` presumably stores nothing, so discovery runs on every
        // call — confirm, and consider caching provider metadata.
        let metadata_cache = crate::cache::NoOpCache;
        let metadata = discover(&expected_issuer, self.http.as_ref(), &metadata_cache).await?;
        let jwks = fetch_jwks(&metadata.jwks_uri, self.http.as_ref(), self.cache.as_ref()).await?;

        let jwk = jwks
            .find_key(&kid)
            .ok_or_else(|| Error::Jwt(format!("Key with kid '{}' not found", kid)))?;

        let payload = verify_token_signature(token, jwk)?;

        extract_and_validate_access_token_claims(
            payload,
            &expected_issuer,
            &self.audience,
            self.clock_skew_sec,
        )
    }

    /// Strips surrounding whitespace and an optional, case-insensitive `Bearer ` scheme.
    fn strip_bearer(value: &str) -> &str {
        const SCHEME: &str = "Bearer ";
        let value = value.trim();
        // `get` returns `None` when the input is too short or the cut is not on a
        // char boundary, so the slice below cannot panic.
        match value.get(..SCHEME.len()) {
            Some(prefix) if prefix.eq_ignore_ascii_case(SCHEME) => value[SCHEME.len()..].trim_start(),
            _ => value,
        }
    }

    /// Maps a token's (unverified) `iss` claim to the issuer whose keys should verify it.
    ///
    /// Resolution order:
    /// 1. `token_issuer` equals a value of `issuer_map`: it is returned unchanged.
    /// 2. `token_issuer` equals a key of `issuer_map`: the mapped value is returned.
    /// 3. Otherwise `default_issuer` is returned, if configured.
    ///
    /// NOTE(review): in cases 2 and 3 the result can differ from the token's `iss`. The
    /// `iss` equality check in `extract_and_validate_access_token_claims` then rejects the
    /// token. Confirm this is intended.
    ///
    /// # Errors
    ///
    /// `Error::Verification` when none of the above applies.
    fn resolve_issuer(&self, token_issuer: &str) -> Result<String> {
        if self.issuer_map.values().any(|v| v == token_issuer) {
            return Ok(token_issuer.to_string());
        }

        // Single lookup instead of `contains_key` followed by indexing.
        if let Some(issuer) = self.issuer_map.get(token_issuer) {
            return Ok(issuer.clone());
        }

        if let Some(default) = &self.default_issuer {
            return Ok(default.clone());
        }

        Err(Error::Verification(format!("Issuer '{}' not in allowed list", token_issuer)))
    }

    /// Returns the issuer configured for `tenant`, falling back to `default_issuer`.
    ///
    /// # Errors
    ///
    /// `Error::Verification` when the tenant is unknown and no default issuer is set.
    pub fn resolve_issuer_with_tenant(&self, tenant: &str) -> Result<String> {
        self.issuer_map
            .get(tenant)
            .or(self.default_issuer.as_ref())
            .cloned()
            .ok_or_else(|| {
                Error::Verification(format!("No issuer configured for tenant '{}'", tenant))
            })
    }
}
136
#[cfg(feature = "verifier")]
/// Step-by-step builder for [`JwtVerifier`].
///
/// `audience`, `http`, and `cache` are required. When [`JwtVerifierBuilder::build`] is
/// called, an unset `issuer_map` becomes empty and an unset clock skew becomes 60 seconds.
#[must_use = "a JwtVerifierBuilder does nothing until `build` is called"]
pub struct JwtVerifierBuilder<C: Cache<String, Jwks>, H: HttpClient> {
    /// Issuer lookup table; see [`JwtVerifier::issuer_map`].
    issuer_map: Option<HashMap<String, String>>,
    /// Required audience.
    audience: Option<String>,
    /// Required HTTP client.
    http: Option<std::sync::Arc<H>>,
    /// Required JWKS cache.
    cache: Option<std::sync::Arc<C>>,
    /// Clock-skew leeway in seconds.
    clock_skew_sec: Option<i64>,
    /// Fallback issuer.
    default_issuer: Option<String>,
}
147
#[cfg(feature = "verifier")]
// Written by hand because `#[derive(Default)]` would add `C: Default` and `H: Default`
// bounds that the cache and HTTP client types need not satisfy.
impl<C: Cache<String, Jwks>, H: HttpClient> Default for JwtVerifierBuilder<C, H> {
    /// Returns a builder with every setting left unset.
    fn default() -> Self {
        JwtVerifierBuilder {
            audience: None,
            issuer_map: None,
            cache: None,
            http: None,
            default_issuer: None,
            clock_skew_sec: None,
        }
    }
}
161
#[cfg(feature = "verifier")]
impl<C: Cache<String, Jwks>, H: HttpClient> JwtVerifierBuilder<C, H> {
    /// Sets the issuer lookup table.
    pub fn issuer_map(self, map: HashMap<String, String>) -> Self {
        Self { issuer_map: Some(map), ..self }
    }

    /// Sets the audience that tokens must be issued for.
    pub fn audience(self, audience: impl Into<String>) -> Self {
        Self { audience: Some(audience.into()), ..self }
    }

    /// Sets the HTTP client used for discovery and JWKS retrieval.
    pub fn http(self, http: std::sync::Arc<H>) -> Self {
        Self { http: Some(http), ..self }
    }

    /// Sets the cache handed to JWKS retrieval.
    pub fn cache(self, cache: std::sync::Arc<C>) -> Self {
        Self { cache: Some(cache), ..self }
    }

    /// Sets the clock-skew leeway, in seconds, for time-based claim checks.
    pub fn clock_skew(self, seconds: i64) -> Self {
        Self { clock_skew_sec: Some(seconds), ..self }
    }

    /// Sets the issuer used when a token's issuer cannot be resolved from the map.
    pub fn default_issuer(self, issuer: impl Into<String>) -> Self {
        Self { default_issuer: Some(issuer.into()), ..self }
    }

    /// Finalizes the configuration.
    ///
    /// # Errors
    ///
    /// `Error::MissingConfig` naming the first missing required setting, checked in the
    /// order audience, HTTP client, cache.
    pub fn build(self) -> Result<JwtVerifier<C, H>> {
        let Self { issuer_map, audience, http, cache, clock_skew_sec, default_issuer } = self;

        let audience = audience.ok_or(Error::MissingConfig("audience"))?;
        let http = http.ok_or(Error::MissingConfig("http client"))?;
        let cache = cache.ok_or(Error::MissingConfig("cache"))?;

        Ok(JwtVerifier {
            issuer_map: issuer_map.unwrap_or_default(),
            audience,
            http,
            cache,
            clock_skew_sec: clock_skew_sec.unwrap_or(60),
            default_issuer,
        })
    }
}
212
#[cfg(feature = "verifier")]
/// Reads the `kid` (key ID) header parameter of a compact JWT without verifying anything.
///
/// Returns `Ok(None)` when the header decodes but has no string-valued `kid`.
///
/// # Errors
///
/// - `Error::Jwt` if the token does not have exactly three dot-separated segments, or if
///   the header is not valid JSON.
/// - `Error::Base64` if the header is not unpadded base64url.
fn extract_kid(jwt: &str) -> Result<Option<String>> {
    let segments: Vec<&str> = jwt.split('.').collect();
    let encoded_header = match segments.as_slice() {
        [header, _, _] => *header,
        _ => return Err(Error::Jwt("Invalid JWT format".into())),
    };

    let raw_header = general_purpose::URL_SAFE_NO_PAD
        .decode(encoded_header)
        .map_err(|e| Error::Base64(format!("Failed to decode header: {}", e)))?;
    let header: serde_json::Value = serde_json::from_slice(&raw_header)
        .map_err(|e| Error::Jwt(format!("Failed to parse header JSON: {}", e)))?;

    let kid = header.get("kid").and_then(serde_json::Value::as_str).map(str::to_owned);
    Ok(kid)
}
230
#[cfg(feature = "verifier")]
/// Reads the `iss` claim from a compact JWT's payload without verifying the signature.
///
/// The result is untrusted and should only be used to choose which issuer to verify against.
///
/// # Errors
///
/// - `Error::Jwt` if the token does not have exactly three segments, or if `iss` is absent
///   or not a string.
/// - `Error::Base64` if the payload is not unpadded base64url.
/// - The crate's JSON error conversion if the payload is not valid JSON.
fn extract_unverified_issuer(token: &str) -> Result<String> {
    let segments: Vec<&str> = token.split('.').collect();
    let encoded_payload = match segments.as_slice() {
        [_, payload, _] => *payload,
        _ => return Err(Error::Jwt("Invalid JWT format".into())),
    };

    let raw_payload = general_purpose::URL_SAFE_NO_PAD
        .decode(encoded_payload)
        .map_err(|e| Error::Base64(e.to_string()))?;
    let claims: serde_json::Value = serde_json::from_slice(&raw_payload)?;

    claims
        .get("iss")
        .and_then(serde_json::Value::as_str)
        .map(str::to_owned)
        .ok_or_else(|| Error::Jwt("Token missing issuer".into()))
}
250
#[cfg(feature = "verifier")]
/// Verifies the JWS signature of `token` with `jwk` and returns the decoded payload.
///
/// Only `RS256` is accepted. A JWK without an `alg` member is treated as `RS256`.
///
/// # Errors
///
/// `Error::Jwt` in any of these cases:
/// - the JWK cannot be converted into a josekit key;
/// - its algorithm is unsupported;
/// - a verifier cannot be built;
/// - signature verification or decoding fails.
fn verify_token_signature(token: &str, jwk: &crate::jwks::Jwk) -> Result<JwtPayload> {
    // Convert our JWK into josekit's representation through its JSON form. Matching on the
    // value (rather than `as_object().unwrap().clone()`) removes a panic path and a copy.
    let jwk_map = match serde_json::to_value(jwk)? {
        serde_json::Value::Object(map) => map,
        _ => return Err(Error::Jwt("Invalid JWK: expected a JSON object".into())),
    };
    let key = josekit::jwk::Jwk::from_map(jwk_map)
        .map_err(|e| Error::Jwt(format!("Invalid JWK: {}", e)))?;

    // Explicit algorithm allow-list; anything other than RS256 is rejected.
    let verifier = match jwk.alg.as_deref().unwrap_or("RS256") {
        "RS256" => RS256.verifier_from_jwk(&key),
        alg => return Err(Error::Jwt(format!("Unsupported algorithm: {}", alg))),
    }
    .map_err(|e| Error::Jwt(format!("Failed to create verifier: {}", e)))?;

    let (payload, _header) = jwt::decode_with_verifier(token, &verifier)
        .map_err(|e| Error::Jwt(format!("Token verification failed: {}", e)))?;

    Ok(payload)
}
271
#[cfg(feature = "verifier")]
/// Validates the registered claims of a signature-verified access token and maps them into
/// [`VerifiedClaims`].
///
/// Every time comparison allows `clock_skew` seconds of leeway. The checks are:
/// - `iss` equals `expected_issuer`;
/// - `aud` contains `expected_audience`;
/// - `exp` has not been reached (RFC 7519 §4.1.4);
/// - `iat` is not in the future;
/// - `nbf`, when present, is not in the future (RFC 7519 §4.1.5).
///
/// `iss`, `sub`, `aud`, `exp` and `iat` are required.
///
/// The optional claims `jti`, `scope`, `xjp_admin`, `amr` and `auth_time` are copied when
/// present with the expected JSON type. `jti` falls back to an empty string.
///
/// # Errors
///
/// `Error::Verification` when:
/// - a required claim is missing or invalid;
/// - any check fails;
/// - the system clock reads earlier than the Unix epoch.
fn extract_and_validate_access_token_claims(
    payload: JwtPayload,
    expected_issuer: &str,
    expected_audience: &str,
    clock_skew: i64,
) -> Result<VerifiedClaims> {
    // A clock before 1970 is an environment problem; report it instead of panicking.
    let now = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_err(|_| Error::Verification("System clock is before the Unix epoch".into()))?
        .as_secs() as i64;
    // Leeway window edges; saturating so an extreme configured skew cannot overflow.
    let earliest = now.saturating_sub(clock_skew);
    let latest = now.saturating_add(clock_skew);

    let iss = payload.issuer().ok_or_else(|| Error::Verification("Missing iss claim".into()))?;
    let sub = payload.subject().ok_or_else(|| Error::Verification("Missing sub claim".into()))?;
    let exp = required_time_claim_secs(payload.expires_at(), "exp")?;
    let iat = required_time_claim_secs(payload.issued_at(), "iat")?;

    if iss != expected_issuer {
        return Err(Error::Verification(format!(
            "Invalid issuer: expected '{}', got '{}'",
            expected_issuer, iss
        )));
    }

    let audiences = payload
        .audience()
        .ok_or_else(|| Error::Verification("Missing aud claim".into()))?;
    if !audiences.iter().any(|a| *a == expected_audience) {
        return Err(Error::Verification(format!(
            "Invalid audience: expected '{}'",
            expected_audience
        )));
    }

    // RFC 7519 §4.1.4: the token must not be accepted on or after `exp`, so reaching the
    // leeway-adjusted expiry exactly already counts as expired.
    if exp <= earliest {
        return Err(Error::Verification("Token expired".into()));
    }

    if iat > latest {
        return Err(Error::Verification("Token issued in the future".into()));
    }

    // RFC 7519 §4.1.5: the token must not be accepted before its `nbf` time.
    if let Some(nbf) = payload.not_before() {
        if time_claim_secs(nbf, "nbf")? > latest {
            return Err(Error::Verification("Token not yet valid".into()));
        }
    }

    let claims_map = payload.claims_set();

    let jti = claims_map.get("jti").and_then(|v| v.as_str()).unwrap_or("").to_string();

    let scope = claims_map.get("scope").and_then(|v| v.as_str()).map(|s| s.to_string());

    let xjp_admin = claims_map.get("xjp_admin").and_then(|v| v.as_bool());

    // `amr` is kept only if it is an array whose every element is a string.
    let amr = claims_map.get("amr").and_then(|v| {
        v.as_array()?
            .iter()
            .map(|item| item.as_str().map(|s| s.to_string()))
            .collect::<Option<Vec<String>>>()
    });

    let auth_time = claims_map.get("auth_time").and_then(|v| v.as_i64());

    Ok(VerifiedClaims {
        iss: iss.to_string(),
        sub: sub.to_string(),
        aud: expected_audience.to_string(),
        exp,
        iat,
        jti,
        scope,
        xjp_admin,
        amr,
        auth_time,
    })
}

#[cfg(feature = "verifier")]
/// Converts a time-valued claim into whole seconds since the Unix epoch.
///
/// Fails with `Error::Verification("Invalid <claim> time")` for instants before the epoch.
fn time_claim_secs(time: SystemTime, claim: &str) -> Result<i64> {
    time.duration_since(UNIX_EPOCH)
        .map(|d| d.as_secs() as i64)
        .map_err(|_| Error::Verification(format!("Invalid {} time", claim)))
}

#[cfg(feature = "verifier")]
/// Like `time_claim_secs`, but also fails with `Error::Verification("Missing <claim> claim")`
/// when the claim is absent.
fn required_time_claim_secs(time: Option<SystemTime>, claim: &str) -> Result<i64> {
    let time = time.ok_or_else(|| Error::Verification(format!("Missing {} claim", claim)))?;
    time_claim_secs(time, claim)
}
360
#[cfg(all(test, feature = "verifier"))]
mod tests {
    use super::*;

    // Header {"alg":"RS256","kid":"test-key"}, payload {"iss":"https://auth.example.com"}.
    const TOKEN_WITH_KID_AND_ISS: &str = "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2V5In0.eyJpc3MiOiJodHRwczovL2F1dGguZXhhbXBsZS5jb20ifQ.dummy";
    // Header {"alg":"RS256"}, payload {}.
    const TOKEN_WITHOUT_KID_OR_ISS: &str = "eyJhbGciOiJSUzI1NiJ9.e30.dummy";

    #[test]
    fn test_extract_unverified_issuer() {
        let issuer = extract_unverified_issuer(TOKEN_WITH_KID_AND_ISS).unwrap();
        assert_eq!(issuer, "https://auth.example.com");
    }

    #[test]
    fn extract_unverified_issuer_rejects_missing_iss() {
        assert!(extract_unverified_issuer(TOKEN_WITHOUT_KID_OR_ISS).is_err());
    }

    #[test]
    fn extract_kid_reads_header_kid() {
        assert_eq!(extract_kid(TOKEN_WITH_KID_AND_ISS).unwrap(), Some("test-key".to_string()));
    }

    #[test]
    fn extract_kid_returns_none_when_absent() {
        assert_eq!(extract_kid(TOKEN_WITHOUT_KID_OR_ISS).unwrap(), None);
    }

    #[test]
    fn helpers_reject_wrong_segment_count() {
        for bad in &["", "a.b", "a.b.c.d"] {
            assert!(extract_kid(bad).is_err(), "extract_kid accepted {:?}", bad);
            assert!(
                extract_unverified_issuer(bad).is_err(),
                "extract_unverified_issuer accepted {:?}",
                bad
            );
        }
    }
}