salvo_jwt_auth/oidc/
mod.rs

//! OIDC (OpenID Connect) support.
2
3use std::fmt::{self, Debug, Formatter};
4use std::str::FromStr;
5use std::sync::Arc;
6use std::time::SystemTime;
7use std::time::UNIX_EPOCH;
8
9use base64::Engine;
10use base64::engine::general_purpose::URL_SAFE_NO_PAD;
11use bytes::Bytes;
12use http_body_util::{BodyExt, Full};
13use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder};
14use hyper_util::client::legacy::{Client, connect::HttpConnector};
15use hyper_util::rt::TokioExecutor;
16use jsonwebtoken::jwk::{Jwk, JwkSet};
17use jsonwebtoken::{Algorithm, DecodingKey, TokenData, Validation};
18use salvo_core::Depot;
19use salvo_core::http::{header::CACHE_CONTROL, uri::Uri};
20use serde::Deserialize;
21use serde::de::DeserializeOwned;
22use tokio::sync::{Notify, RwLock};
23
24use super::{JwtAuthDecoder, JwtAuthError};
25
26mod cache;
27
28pub use cache::{CachePolicy, CacheState, JwkSetStore, UpdateAction};
29
30pub(super) type HyperClient = Client<HttpsConnector<HttpConnector>, Full<Bytes>>;
31
32/// ConstDecoder will decode token with a static secret.
33#[derive(Clone)]
34pub struct OidcDecoder {
35    issuer: String,
36    http_client: HyperClient,
37    cache: Arc<RwLock<JwkSetStore>>,
38    cache_state: Arc<CacheState>,
39    notifier: Arc<Notify>,
40}
41
42impl Debug for OidcDecoder {
43    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
44        f.debug_struct("OidcDecoder")
45            .field("issuer", &self.issuer)
46            .field("http_client", &self.http_client)
47            .field("cache_state", &self.cache_state)
48            .finish()
49    }
50}
51
52impl JwtAuthDecoder for OidcDecoder {
53    type Error = JwtAuthError;
54
55    /// Validates a JWT, Returning the claims serialized into type of T
56    async fn decode<C>(&self, token: &str, _depot: &mut Depot) -> Result<TokenData<C>, Self::Error>
57    where
58        C: DeserializeOwned + Clone,
59    {
60        // Early return error conditions before acquiring a read lock
61        let header = jsonwebtoken::decode_header(token)?;
62        let kid = header.kid.ok_or(JwtAuthError::MissingKid)?;
63
64        let decoding_key = self.get_kid_retry(kid).await?;
65        let decoded = decoding_key.decode(token)?;
66        Ok(decoded)
67    }
68}
69
70/// A builder for `OidcDecoder`.
71pub struct DecoderBuilder<T>
72where
73    T: AsRef<str>,
74{
75    /// The issuer URL of the token. eg: `https://xx-xx.clerk.accounts.dev`
76    pub issuer: T,
77    /// The http client for the decoder.
78    pub http_client: Option<HyperClient>,
79    /// The validation options for the decoder.
80    pub validation: Option<Validation>,
81}
82impl<T: AsRef<str>> Debug for DecoderBuilder<T> {
83    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
84        f.debug_struct("DecoderBuilder")
85            .field("http_client", &self.http_client)
86            .field("validation", &self.validation)
87            .finish()
88    }
89}
90impl<T> DecoderBuilder<T>
91where
92    T: AsRef<str>,
93{
94    /// Create a new `DecoderBuilder`.
95    #[must_use]
96    pub fn new(issuer: T) -> Self {
97        Self {
98            issuer,
99            http_client: None,
100            validation: None,
101        }
102    }
103    /// Set the http client for the decoder.
104    #[must_use]
105    pub fn http_client(mut self, client: HyperClient) -> Self {
106        self.http_client = Some(client);
107        self
108    }
109    /// Set the validation options for the decoder.
110    #[must_use]
111    pub fn validation(mut self, validation: Validation) -> Self {
112        self.validation = Some(validation);
113        self
114    }
115
116    /// Build a `OidcDecoder`.
117    pub fn build(self) -> impl Future<Output = Result<OidcDecoder, JwtAuthError>> {
118        let Self {
119            issuer,
120            http_client,
121            validation,
122        } = self;
123        let issuer = issuer.as_ref().trim_end_matches('/').to_owned();
124
125        //Create an empty JWKS to initialize our Cache
126        let jwks = JwkSet { keys: Vec::new() };
127
128        let validation = validation.unwrap_or_default();
129        let cache = Arc::new(RwLock::new(JwkSetStore::new(
130            jwks,
131            CachePolicy::default(),
132            validation,
133        )));
134        let cache_state = Arc::new(CacheState::new());
135
136        let https = HttpsConnectorBuilder::new()
137            .with_native_roots()
138            .expect("no native root CA certificates found")
139            .https_only()
140            .enable_http1()
141            .build();
142        let http_client =
143            http_client.unwrap_or_else(|| Client::builder(TokioExecutor::new()).build(https));
144        let decoder = OidcDecoder {
145            issuer,
146            http_client,
147            cache,
148            cache_state,
149            notifier: Arc::new(Notify::new()),
150        };
151        async move {
152            decoder.update_cache().await?;
153            Ok(decoder)
154        }
155    }
156}
157
158impl OidcDecoder {
159    /// Create a new `OidcDecoder`.
160    pub fn new(issuer: impl AsRef<str>) -> impl Future<Output = Result<Self, JwtAuthError>> {
161        Self::builder(issuer).build()
162    }
163
164    /// Create a new `DecoderBuilder`.
165    pub fn builder<T>(issuer: T) -> DecoderBuilder<T>
166    where
167        T: AsRef<str>,
168    {
169        DecoderBuilder::new(issuer)
170    }
171
172    fn config_url(&self) -> String {
173        format!("{}/.well-known/openid-configuration", &self.issuer)
174    }
175    async fn get_config(&self) -> Result<OidcConfig, JwtAuthError> {
176        let res = self
177            .http_client
178            .get(self.config_url().parse::<Uri>()?)
179            .await?;
180        let body = res.into_body().collect().await?.to_bytes();
181        let config = serde_json::from_slice(&body)?;
182        Ok(config)
183    }
184    async fn jwks_uri(&self) -> Result<String, JwtAuthError> {
185        Ok(self.get_config().await?.jwks_uri)
186    }
187
188    /// Triggers an HTTP Request to get a fresh `JwkSet`
189    async fn get_jwks(&self) -> Result<JwkSetFetch, JwtAuthError> {
190        let uri = self.jwks_uri().await?.parse::<Uri>()?;
191        // Get the jwks endpoint
192        tracing::debug!("Requesting JWKS From Uri: {uri}");
193        let res = self.http_client.get(uri).await?;
194
195        let cache_policy = {
196            // Determine it from the cache_control header
197            let cache_control = res.headers().get(CACHE_CONTROL);
198            let cache_policy = CachePolicy::from_header_val(cache_control);
199            Some(cache_policy)
200        };
201        let jwks = res.into_body().collect().await?.to_bytes();
202
203        let fetched_at = current_time();
204        Ok(JwkSetFetch {
205            jwks: serde_json::from_slice(&jwks)?,
206            cache_policy,
207            fetched_at,
208        })
209    }
210
211    /// Triggers an immediate update from the JWKS URL
212    /// Will only write lock the [`JwkSetStore`] if there is an actual change to the contents.
213    async fn update_cache(&self) -> Result<UpdateAction, JwtAuthError> {
214        let fetch = self.get_jwks().await;
215        match fetch {
216            Ok(fetch) => {
217                self.cache_state.set_last_update(fetch.fetched_at);
218                tracing::info!("Set Last update to {:#?}", fetch.fetched_at);
219                self.cache_state.set_is_error(false);
220                let read = self.cache.read().await;
221
222                if read.jwks == fetch.jwks
223                    && fetch.cache_policy.unwrap_or(read.cache_policy) == read.cache_policy
224                {
225                    return Ok(UpdateAction::NoUpdate);
226                }
227                drop(read);
228                let mut write = self.cache.write().await;
229
230                Ok(write.update_fetch(fetch))
231            }
232            Err(e) => {
233                self.cache_state.set_is_error(true);
234                Err(e)
235            }
236        }
237    }
238    /// Triggers an eventual update from the JWKS URL
239    /// Will only ever spawn one task at a single time.
240    /// If called while an update task is currently running, will do nothing.
241    fn revalidate_cache(&self) {
242        if !self.cache_state.is_revalidating() {
243            self.cache_state.set_is_revalidating(true);
244            tracing::info!("Spawning Task to re-validate JWKS");
245            let a = self.clone();
246            tokio::task::spawn(async move {
247                let _ = a.update_cache().await;
248                a.cache_state.set_is_revalidating(false);
249                a.notifier.notify_waiters();
250            });
251        }
252    }
253
254    /// If we are currently updating the JWKS in the background this function will resolve when the update it complete
255    /// If we are not currently updating the JWKS in the background, this function will resolve immediately.
256    async fn wait_update(&self) {
257        if self.cache_state.is_revalidating() {
258            self.notifier.notified().await;
259        }
260    }
261
262    /// Primary method for getting the [`DecodingInfo`] for a JWK needed to validate a JWT.
263    /// If the kid was not present in [`JwkSetStore`]
264    #[allow(clippy::future_not_send)]
265    async fn get_kid_retry(&self, kid: impl AsRef<str>) -> Result<Arc<DecodingInfo>, JwtAuthError> {
266        let kid = kid.as_ref();
267        // Check to see if we have the kid
268        if let Ok(Some(key)) = self.get_kid(kid).await {
269            // if we have it, then return it
270            Ok(key)
271        } else {
272            // Try and invalidate our cache. Maybe the JWKS has changed or our cached values expired
273            // Even if it failed it. It may allow us to retrieve a key from stale-if-error
274            self.revalidate_cache();
275            self.wait_update().await;
276            self.get_kid(kid).await?.ok_or(JwtAuthError::CacheError)
277        }
278    }
279
280    /// Gets the decoding components of a JWK by kid from the JWKS in our cache
281    /// Returns an Error, if the cache is stale and beyond the Stale While Revalidate and Stale If Error allowances configured in [`crate::cache::Settings`]
282    /// Returns Ok if the cache is not stale.
283    /// Returns Ok after triggering a background update of the JWKS If the cache is stale but within the Stale While Revalidate and Stale If Error allowances.
284    #[allow(clippy::future_not_send)]
285    async fn get_kid(&self, kid: &str) -> Result<Option<Arc<DecodingInfo>>, JwtAuthError> {
286        let read_cache = self.cache.read().await;
287        let fetched = self.cache_state.last_update();
288        let max_age_secs = read_cache.cache_policy.max_age.as_secs();
289
290        let max_age = fetched + max_age_secs;
291        let now = current_time();
292        let val = read_cache.get_key(kid);
293
294        if now <= max_age {
295            return Ok(val);
296        }
297
298        // If the stale while revalidate setting is present
299        if let Some(swr) = read_cache.cache_policy.stale_while_revalidate {
300            // if we're within the SWR allowed window
301            if now <= swr.as_secs() + max_age {
302                self.revalidate_cache();
303                return Ok(val);
304            }
305        }
306        if let Some(swr_err) = read_cache.cache_policy.stale_if_error {
307            // if the last update failed and the stale-if-error is present
308            if now <= swr_err.as_secs() + max_age && self.cache_state.is_error() {
309                self.revalidate_cache();
310                return Ok(val);
311            }
312        }
313        drop(read_cache);
314        tracing::info!("Returning None: {now} - {max_age}");
315        Err(JwtAuthError::CacheError)
316    }
317}
318
319/// Struct used to store the computed information needed to decode a JWT
320/// Intended to be cached inside of [`JwkSetStore`] to prevent decoding information about the same JWK more than once
321pub struct DecodingInfo {
322    // jwk: Jwk,
323    key: DecodingKey,
324    validation: Validation,
325    // alg: Algorithm,
326}
327impl Debug for DecodingInfo {
328    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
329        f.debug_struct("DecodingInfo")
330            .field("validation", &self.validation)
331            .finish()
332    }
333}
334impl DecodingInfo {
335    fn new(key: DecodingKey, alg: Algorithm, validation_settings: &Validation) -> Self {
336        let mut validation = Validation::new(alg);
337
338        validation.aud.clone_from(&validation_settings.aud);
339        validation.iss.clone_from(&validation_settings.iss);
340        validation.leeway = validation_settings.leeway;
341        validation
342            .required_spec_claims
343            .clone_from(&validation_settings.required_spec_claims);
344
345        validation.sub.clone_from(&validation_settings.sub);
346        validation.validate_exp = validation_settings.validate_exp;
347        validation.validate_nbf = validation_settings.validate_nbf;
348
349        Self {
350            // jwk,
351            key,
352            validation,
353            // alg,
354        }
355    }
356
357    fn decode<T>(&self, token: &str) -> Result<TokenData<T>, JwtAuthError>
358    where
359        T: for<'de> serde::de::Deserialize<'de> + Clone,
360    {
361        match jsonwebtoken::decode::<T>(token, &self.key, &self.validation) {
362            Ok(data) => Ok(data),
363            Err(e) => {
364                tracing::error!(error = ?e, token, "error decoding jwt token");
365                Err(JwtAuthError::from(e))
366            }
367        }
368    }
369}
370
371/// Helper Struct that contains the response of a request to the jwks uri
372/// `cache_policy` will be Some when [`cache::Strategy`] is set to [`cache::Strategy::Automatic`].
373#[derive(Debug)]
374pub(crate) struct JwkSetFetch {
375    jwks: JwkSet,
376    cache_policy: Option<CachePolicy>,
377    fetched_at: u64,
378}
379
380#[derive(Debug, Deserialize)]
381struct OidcConfig {
382    jwks_uri: String,
383}
384
385pub(crate) fn decode_jwk(
386    jwk: &Jwk,
387    validation: &Validation,
388) -> Result<(String, DecodingInfo), JwtAuthError> {
389    let kid = jwk.common.key_id.clone();
390    let alg = jwk.common.key_algorithm;
391
392    let dec_key = match jwk.algorithm {
393        jsonwebtoken::jwk::AlgorithmParameters::EllipticCurve(ref params) => {
394            let x_cmp = b64_decode(&params.x)?;
395            let y_cmp = b64_decode(&params.y)?;
396            let mut public_key = Vec::with_capacity(1 + params.x.len() + params.y.len());
397            public_key.push(0x04);
398            public_key.extend_from_slice(&x_cmp);
399            public_key.extend_from_slice(&y_cmp);
400            Some(DecodingKey::from_ec_der(&public_key))
401        }
402        jsonwebtoken::jwk::AlgorithmParameters::RSA(ref params) => {
403            DecodingKey::from_rsa_components(&params.n, &params.e).ok()
404        }
405        jsonwebtoken::jwk::AlgorithmParameters::OctetKey(ref params) => {
406            DecodingKey::from_base64_secret(&params.value).ok()
407        }
408        jsonwebtoken::jwk::AlgorithmParameters::OctetKeyPair(ref params) => {
409            let der = b64_decode(&params.x)?;
410
411            Some(DecodingKey::from_ed_der(&der))
412        }
413    };
414    match (kid, alg, dec_key) {
415        (Some(kid), Some(alg), Some(dec_key)) => {
416            let alg = Algorithm::from_str(alg.to_string().as_str())?;
417            let info = DecodingInfo::new(dec_key, alg, validation);
418            Ok((kid, info))
419        }
420        _ => Err(JwtAuthError::InvalidJwk),
421    }
422}
423
424fn b64_decode<T: AsRef<[u8]>>(input: T) -> Result<Vec<u8>, base64::DecodeError> {
425    URL_SAFE_NO_PAD.decode(input.as_ref())
426}
427
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// # Panics
/// Panics if the system clock reports a time before the Unix epoch.
pub(crate) fn current_time() -> u64 {
    let now = SystemTime::now();
    now.duration_since(UNIX_EPOCH)
        .map(|elapsed| elapsed.as_secs())
        .expect("Time Went Backwards")
}
434
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    /// Builds an RSA JWK with valid base64url `n`/`e`, so decoding the key material always
    /// succeeds and each test isolates the `kid`/`alg` handling.
    fn rsa_jwk(kid: Option<&str>, alg: Option<&str>) -> Jwk {
        let mut value = json!({
            "kty": "RSA",
            "n": "AQAB",
            "e": "AQAB"
        });
        if let Some(kid) = kid {
            value["kid"] = json!(kid);
        }
        if let Some(alg) = alg {
            value["alg"] = json!(alg);
        }
        serde_json::from_value(value).unwrap()
    }

    #[test]
    fn test_decode_jwk_missing_alg() {
        let jwk = rsa_jwk(Some("test-rsa"), None);
        let validation = Validation::new(Algorithm::RS256);
        let result = decode_jwk(&jwk, &validation);
        assert!(matches!(result, Err(JwtAuthError::InvalidJwk)));
    }

    #[test]
    fn test_decode_jwk_missing_kid() {
        let jwk = rsa_jwk(None, Some("RS256"));
        let validation = Validation::new(Algorithm::RS256);
        let result = decode_jwk(&jwk, &validation);
        assert!(matches!(result, Err(JwtAuthError::InvalidJwk)));
    }

    #[test]
    fn test_decode_jwk_rsa_with_kid_and_alg() {
        let jwk = rsa_jwk(Some("test-rsa"), Some("RS256"));
        let validation = Validation::new(Algorithm::RS256);
        let Ok((kid, info)) = decode_jwk(&jwk, &validation) else {
            panic!("a well-formed RSA JWK was rejected");
        };
        assert_eq!(kid, "test-rsa");
        assert_eq!(info.validation.algorithms, vec![Algorithm::RS256]);
    }
}