1use std::fmt::{self, Debug, Formatter};
4use std::str::FromStr;
5use std::sync::Arc;
6use std::time::SystemTime;
7use std::time::UNIX_EPOCH;
8
9use base64::Engine;
10use base64::engine::general_purpose::URL_SAFE_NO_PAD;
11use bytes::Bytes;
12use http_body_util::{BodyExt, Full};
13use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder};
14use hyper_util::client::legacy::{Client, connect::HttpConnector};
15use hyper_util::rt::TokioExecutor;
16use jsonwebtoken::jwk::{Jwk, JwkSet};
17use jsonwebtoken::{Algorithm, DecodingKey, TokenData, Validation};
18use salvo_core::Depot;
19use salvo_core::http::{header::CACHE_CONTROL, uri::Uri};
20use serde::Deserialize;
21use serde::de::DeserializeOwned;
22use tokio::sync::{Notify, RwLock};
23
24use super::{JwtAuthDecoder, JwtAuthError};
25
26mod cache;
27
28pub use cache::{CachePolicy, CacheState, JwkSetStore, UpdateAction};
29
30pub(super) type HyperClient = Client<HttpsConnector<HttpConnector>, Full<Bytes>>;
31
32#[derive(Clone)]
34pub struct OidcDecoder {
35 issuer: String,
36 http_client: HyperClient,
37 cache: Arc<RwLock<JwkSetStore>>,
38 cache_state: Arc<CacheState>,
39 notifier: Arc<Notify>,
40}
41
42impl Debug for OidcDecoder {
43 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
44 f.debug_struct("OidcDecoder")
45 .field("issuer", &self.issuer)
46 .field("http_client", &self.http_client)
47 .field("cache_state", &self.cache_state)
48 .finish()
49 }
50}
51
52impl JwtAuthDecoder for OidcDecoder {
53 type Error = JwtAuthError;
54
55 async fn decode<C>(&self, token: &str, _depot: &mut Depot) -> Result<TokenData<C>, Self::Error>
57 where
58 C: DeserializeOwned + Clone,
59 {
60 let header = jsonwebtoken::decode_header(token)?;
62 let kid = header.kid.ok_or(JwtAuthError::MissingKid)?;
63
64 let decoding_key = self.get_kid_retry(kid).await?;
65 let decoded = decoding_key.decode(token)?;
66 Ok(decoded)
67 }
68}
69
70pub struct DecoderBuilder<T>
72where
73 T: AsRef<str>,
74{
75 pub issuer: T,
77 pub http_client: Option<HyperClient>,
79 pub validation: Option<Validation>,
81}
82impl<T: AsRef<str>> Debug for DecoderBuilder<T> {
83 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
84 f.debug_struct("DecoderBuilder")
85 .field("http_client", &self.http_client)
86 .field("validation", &self.validation)
87 .finish()
88 }
89}
90impl<T> DecoderBuilder<T>
91where
92 T: AsRef<str>,
93{
94 #[must_use]
96 pub fn new(issuer: T) -> Self {
97 Self {
98 issuer,
99 http_client: None,
100 validation: None,
101 }
102 }
103 #[must_use]
105 pub fn http_client(mut self, client: HyperClient) -> Self {
106 self.http_client = Some(client);
107 self
108 }
109 #[must_use]
111 pub fn validation(mut self, validation: Validation) -> Self {
112 self.validation = Some(validation);
113 self
114 }
115
116 pub fn build(self) -> impl Future<Output = Result<OidcDecoder, JwtAuthError>> {
118 let Self {
119 issuer,
120 http_client,
121 validation,
122 } = self;
123 let issuer = issuer.as_ref().trim_end_matches('/').to_owned();
124
125 let jwks = JwkSet { keys: Vec::new() };
127
128 let validation = validation.unwrap_or_default();
129 let cache = Arc::new(RwLock::new(JwkSetStore::new(
130 jwks,
131 CachePolicy::default(),
132 validation,
133 )));
134 let cache_state = Arc::new(CacheState::new());
135
136 let https = HttpsConnectorBuilder::new()
137 .with_native_roots()
138 .expect("no native root CA certificates found")
139 .https_only()
140 .enable_http1()
141 .build();
142 let http_client =
143 http_client.unwrap_or_else(|| Client::builder(TokioExecutor::new()).build(https));
144 let decoder = OidcDecoder {
145 issuer,
146 http_client,
147 cache,
148 cache_state,
149 notifier: Arc::new(Notify::new()),
150 };
151 async move {
152 decoder.update_cache().await?;
153 Ok(decoder)
154 }
155 }
156}
157
158impl OidcDecoder {
159 pub fn new(issuer: impl AsRef<str>) -> impl Future<Output = Result<Self, JwtAuthError>> {
161 Self::builder(issuer).build()
162 }
163
164 pub fn builder<T>(issuer: T) -> DecoderBuilder<T>
166 where
167 T: AsRef<str>,
168 {
169 DecoderBuilder::new(issuer)
170 }
171
172 fn config_url(&self) -> String {
173 format!("{}/.well-known/openid-configuration", &self.issuer)
174 }
175 async fn get_config(&self) -> Result<OidcConfig, JwtAuthError> {
176 let res = self
177 .http_client
178 .get(self.config_url().parse::<Uri>()?)
179 .await?;
180 let body = res.into_body().collect().await?.to_bytes();
181 let config = serde_json::from_slice(&body)?;
182 Ok(config)
183 }
184 async fn jwks_uri(&self) -> Result<String, JwtAuthError> {
185 Ok(self.get_config().await?.jwks_uri)
186 }
187
188 async fn get_jwks(&self) -> Result<JwkSetFetch, JwtAuthError> {
190 let uri = self.jwks_uri().await?.parse::<Uri>()?;
191 tracing::debug!("Requesting JWKS From Uri: {uri}");
193 let res = self.http_client.get(uri).await?;
194
195 let cache_policy = {
196 let cache_control = res.headers().get(CACHE_CONTROL);
198 let cache_policy = CachePolicy::from_header_val(cache_control);
199 Some(cache_policy)
200 };
201 let jwks = res.into_body().collect().await?.to_bytes();
202
203 let fetched_at = current_time();
204 Ok(JwkSetFetch {
205 jwks: serde_json::from_slice(&jwks)?,
206 cache_policy,
207 fetched_at,
208 })
209 }
210
211 async fn update_cache(&self) -> Result<UpdateAction, JwtAuthError> {
214 let fetch = self.get_jwks().await;
215 match fetch {
216 Ok(fetch) => {
217 self.cache_state.set_last_update(fetch.fetched_at);
218 tracing::info!("Set Last update to {:#?}", fetch.fetched_at);
219 self.cache_state.set_is_error(false);
220 let read = self.cache.read().await;
221
222 if read.jwks == fetch.jwks
223 && fetch.cache_policy.unwrap_or(read.cache_policy) == read.cache_policy
224 {
225 return Ok(UpdateAction::NoUpdate);
226 }
227 drop(read);
228 let mut write = self.cache.write().await;
229
230 Ok(write.update_fetch(fetch))
231 }
232 Err(e) => {
233 self.cache_state.set_is_error(true);
234 Err(e)
235 }
236 }
237 }
238 fn revalidate_cache(&self) {
242 if !self.cache_state.is_revalidating() {
243 self.cache_state.set_is_revalidating(true);
244 tracing::info!("Spawning Task to re-validate JWKS");
245 let a = self.clone();
246 tokio::task::spawn(async move {
247 let _ = a.update_cache().await;
248 a.cache_state.set_is_revalidating(false);
249 a.notifier.notify_waiters();
250 });
251 }
252 }
253
254 async fn wait_update(&self) {
257 if self.cache_state.is_revalidating() {
258 self.notifier.notified().await;
259 }
260 }
261
262 #[allow(clippy::future_not_send)]
265 async fn get_kid_retry(&self, kid: impl AsRef<str>) -> Result<Arc<DecodingInfo>, JwtAuthError> {
266 let kid = kid.as_ref();
267 if let Ok(Some(key)) = self.get_kid(kid).await {
269 Ok(key)
271 } else {
272 self.revalidate_cache();
275 self.wait_update().await;
276 self.get_kid(kid).await?.ok_or(JwtAuthError::CacheError)
277 }
278 }
279
280 #[allow(clippy::future_not_send)]
285 async fn get_kid(&self, kid: &str) -> Result<Option<Arc<DecodingInfo>>, JwtAuthError> {
286 let read_cache = self.cache.read().await;
287 let fetched = self.cache_state.last_update();
288 let max_age_secs = read_cache.cache_policy.max_age.as_secs();
289
290 let max_age = fetched + max_age_secs;
291 let now = current_time();
292 let val = read_cache.get_key(kid);
293
294 if now <= max_age {
295 return Ok(val);
296 }
297
298 if let Some(swr) = read_cache.cache_policy.stale_while_revalidate {
300 if now <= swr.as_secs() + max_age {
302 self.revalidate_cache();
303 return Ok(val);
304 }
305 }
306 if let Some(swr_err) = read_cache.cache_policy.stale_if_error {
307 if now <= swr_err.as_secs() + max_age && self.cache_state.is_error() {
309 self.revalidate_cache();
310 return Ok(val);
311 }
312 }
313 drop(read_cache);
314 tracing::info!("Returning None: {now} - {max_age}");
315 Err(JwtAuthError::CacheError)
316 }
317}
318
319pub struct DecodingInfo {
322 key: DecodingKey,
324 validation: Validation,
325 }
327impl Debug for DecodingInfo {
328 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
329 f.debug_struct("DecodingInfo")
330 .field("validation", &self.validation)
331 .finish()
332 }
333}
334impl DecodingInfo {
335 fn new(key: DecodingKey, alg: Algorithm, validation_settings: &Validation) -> Self {
336 let mut validation = Validation::new(alg);
337
338 validation.aud.clone_from(&validation_settings.aud);
339 validation.iss.clone_from(&validation_settings.iss);
340 validation.leeway = validation_settings.leeway;
341 validation
342 .required_spec_claims
343 .clone_from(&validation_settings.required_spec_claims);
344
345 validation.sub.clone_from(&validation_settings.sub);
346 validation.validate_exp = validation_settings.validate_exp;
347 validation.validate_nbf = validation_settings.validate_nbf;
348
349 Self {
350 key,
352 validation,
353 }
355 }
356
357 fn decode<T>(&self, token: &str) -> Result<TokenData<T>, JwtAuthError>
358 where
359 T: for<'de> serde::de::Deserialize<'de> + Clone,
360 {
361 match jsonwebtoken::decode::<T>(token, &self.key, &self.validation) {
362 Ok(data) => Ok(data),
363 Err(e) => {
364 tracing::error!(error = ?e, token, "error decoding jwt token");
365 Err(JwtAuthError::from(e))
366 }
367 }
368 }
369}
370
371#[derive(Debug)]
374pub(crate) struct JwkSetFetch {
375 jwks: JwkSet,
376 cache_policy: Option<CachePolicy>,
377 fetched_at: u64,
378}
379
380#[derive(Debug, Deserialize)]
381struct OidcConfig {
382 jwks_uri: String,
383}
384
385pub(crate) fn decode_jwk(
386 jwk: &Jwk,
387 validation: &Validation,
388) -> Result<(String, DecodingInfo), JwtAuthError> {
389 let kid = jwk.common.key_id.clone();
390 let alg = jwk.common.key_algorithm;
391
392 let dec_key = match jwk.algorithm {
393 jsonwebtoken::jwk::AlgorithmParameters::EllipticCurve(ref params) => {
394 let x_cmp = b64_decode(¶ms.x)?;
395 let y_cmp = b64_decode(¶ms.y)?;
396 let mut public_key = Vec::with_capacity(1 + params.x.len() + params.y.len());
397 public_key.push(0x04);
398 public_key.extend_from_slice(&x_cmp);
399 public_key.extend_from_slice(&y_cmp);
400 Some(DecodingKey::from_ec_der(&public_key))
401 }
402 jsonwebtoken::jwk::AlgorithmParameters::RSA(ref params) => {
403 DecodingKey::from_rsa_components(¶ms.n, ¶ms.e).ok()
404 }
405 jsonwebtoken::jwk::AlgorithmParameters::OctetKey(ref params) => {
406 DecodingKey::from_base64_secret(¶ms.value).ok()
407 }
408 jsonwebtoken::jwk::AlgorithmParameters::OctetKeyPair(ref params) => {
409 let der = b64_decode(¶ms.x)?;
410
411 Some(DecodingKey::from_ed_der(&der))
412 }
413 };
414 match (kid, alg, dec_key) {
415 (Some(kid), Some(alg), Some(dec_key)) => {
416 let alg = Algorithm::from_str(alg.to_string().as_str())?;
417 let info = DecodingInfo::new(dec_key, alg, validation);
418 Ok((kid, info))
419 }
420 _ => Err(JwtAuthError::InvalidJwk),
421 }
422}
423
424fn b64_decode<T: AsRef<[u8]>>(input: T) -> Result<Vec<u8>, base64::DecodeError> {
425 URL_SAFE_NO_PAD.decode(input.as_ref())
426}
427
/// Returns the current wall-clock time in whole seconds since the Unix epoch.
///
/// # Panics
///
/// Panics if the system clock reports a time earlier than the Unix epoch.
pub(crate) fn current_time() -> u64 {
    let now = SystemTime::now();
    let since_epoch = now
        .duration_since(UNIX_EPOCH)
        .expect("Time Went Backwards");
    since_epoch.as_secs()
}
434
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    /// Builds an RSA JWK with id `test-rsa`, optionally carrying `alg`.
    ///
    /// The modulus and exponent are syntactically valid base64url, so key
    /// construction itself is not what makes `decode_jwk` fail. The previous fixture
    /// used `"n": "..."`, which is not base64 and made the "missing alg" test pass
    /// for the wrong reason.
    fn rsa_jwk(alg: Option<&str>) -> Jwk {
        let mut value = json!({
            "kty": "RSA",
            "kid": "test-rsa",
            "n": "AQAB",
            "e": "AQAB"
        });
        if let Some(alg) = alg {
            value["alg"] = json!(alg);
        }
        serde_json::from_value(value).unwrap()
    }

    #[test]
    fn test_decode_jwk_missing_alg() {
        let validation = Validation::new(Algorithm::RS256);
        let result = decode_jwk(&rsa_jwk(None), &validation);
        assert!(matches!(result, Err(JwtAuthError::InvalidJwk)));
    }

    /// Control case: the same JWK with `alg` present decodes, proving that the
    /// missing `alg` is what the test above exercises.
    #[test]
    fn test_decode_jwk_with_alg() {
        let validation = Validation::new(Algorithm::RS256);
        let Ok((kid, _info)) = decode_jwk(&rsa_jwk(Some("RS256")), &validation) else {
            panic!("a complete RSA JWK should decode");
        };
        assert_eq!(kid, "test-rsa");
    }
}