rusttls_jwt_authorizer/
authorizer.rs

1use std::{io::Read, sync::Arc};
2
3use headers::{authorization::Bearer, Authorization, HeaderMapExt};
4use http::HeaderMap;
5use jsonwebtoken::{decode, decode_header, jwk::JwkSet, Algorithm, DecodingKey, TokenData};
6use reqwest::{Client, Url};
7use serde::de::DeserializeOwned;
8
9use crate::{
10    error::{AuthError, InitError},
11    jwks::{key_store_manager::KeyStoreManager, KeyData, KeySource},
12    layer::{self, AuthorizationLayer, JwtSource},
13    oidc, Refresh, RegisteredClaims,
14};
15
/// Shared predicate run against the decoded claims of an otherwise valid token.
///
/// It must be `Send + Sync` so one checker can be shared between concurrent requests;
/// returning `false` makes [`Authorizer::check_auth`] reject the token.
pub type ClaimsCheckerFn<C> = Arc<Box<dyn Fn(&C) -> bool + Sync + Send>>;
17
18pub struct Authorizer<C = RegisteredClaims>
19where
20    C: Clone + Send,
21{
22    pub key_source: KeySource,
23    pub claims_checker: Option<ClaimsCheckerFn<C>>,
24    pub validation: crate::validation::Validation,
25    pub jwt_source: JwtSource,
26}
27
28fn read_data(path: &str) -> Result<Vec<u8>, InitError> {
29    let mut data = Vec::<u8>::new();
30    let mut f = std::fs::File::open(path)?;
31    f.read_to_end(&mut data)?;
32    Ok(data)
33}
34
/// Describes where an [`Authorizer`] obtains the key(s) used to verify token signatures.
///
/// Variants without a `String` suffix that hold a path read the key from a file;
/// `…String` variants carry the key material inline.
pub enum KeySourceType {
    /// Path to a PEM-encoded RSA public key (accepted for RS256/384/512 and PS256/384/512).
    RSA(String),
    /// PEM-encoded RSA public key text.
    RSAString(String),
    /// Path to a PEM-encoded EC public key (accepted for ES256 and ES384).
    EC(String),
    /// PEM-encoded EC public key text.
    ECString(String),
    /// Path to a PEM-encoded EdDSA public key.
    ED(String),
    /// PEM-encoded EdDSA public key text.
    EDString(String),
    /// Shared HMAC secret (accepted for HS256/384/512).
    Secret(String),
    /// URL of a JWKS document, handed to a `KeyStoreManager`.
    Jwks(String),
    /// Path to a JSON file containing a JWKS document.
    JwksPath(String),
    /// JWKS document as JSON text.
    // TODO: expose JwksString in JwtAuthorizer or remove it
    JwksString(String),
    /// OIDC issuer URL; the JWKS URL is obtained through OIDC discovery.
    Discovery(String),
}
48
49impl<C> Authorizer<C>
50where
51    C: DeserializeOwned + Clone + Send,
52{
53    pub(crate) async fn build(
54        key_source_type: KeySourceType,
55        claims_checker: Option<ClaimsCheckerFn<C>>,
56        refresh: Option<Refresh>,
57        validation: crate::validation::Validation,
58        jwt_source: JwtSource,
59        http_client: Option<Client>,
60    ) -> Result<Authorizer<C>, InitError> {
61        Ok(match key_source_type {
62            KeySourceType::RSA(path) => {
63                let key = DecodingKey::from_rsa_pem(&read_data(path.as_str())?)?;
64                Authorizer {
65                    key_source: KeySource::SingleKeySource(Arc::new(KeyData {
66                        kid: None,
67                        algs: vec![
68                            Algorithm::RS256,
69                            Algorithm::RS384,
70                            Algorithm::RS512,
71                            Algorithm::PS256,
72                            Algorithm::PS384,
73                            Algorithm::PS512,
74                        ],
75                        key,
76                    })),
77                    claims_checker,
78                    validation,
79                    jwt_source,
80                }
81            }
82            KeySourceType::RSAString(text) => {
83                let key = DecodingKey::from_rsa_pem(text.as_bytes())?;
84                Authorizer {
85                    key_source: KeySource::SingleKeySource(Arc::new(KeyData {
86                        kid: None,
87                        algs: vec![
88                            Algorithm::RS256,
89                            Algorithm::RS384,
90                            Algorithm::RS512,
91                            Algorithm::PS256,
92                            Algorithm::PS384,
93                            Algorithm::PS512,
94                        ],
95                        key,
96                    })),
97                    claims_checker,
98                    validation,
99                    jwt_source,
100                }
101            }
102            KeySourceType::EC(path) => {
103                let key = DecodingKey::from_ec_pem(&read_data(path.as_str())?)?;
104                Authorizer {
105                    key_source: KeySource::SingleKeySource(Arc::new(KeyData {
106                        kid: None,
107                        algs: vec![Algorithm::ES256, Algorithm::ES384],
108                        key,
109                    })),
110                    claims_checker,
111                    validation,
112                    jwt_source,
113                }
114            }
115            KeySourceType::ECString(text) => {
116                let key = DecodingKey::from_ec_pem(text.as_bytes())?;
117                Authorizer {
118                    key_source: KeySource::SingleKeySource(Arc::new(KeyData {
119                        kid: None,
120                        algs: vec![Algorithm::ES256, Algorithm::ES384],
121                        key,
122                    })),
123                    claims_checker,
124                    validation,
125                    jwt_source,
126                }
127            }
128            KeySourceType::ED(path) => {
129                let key = DecodingKey::from_ed_pem(&read_data(path.as_str())?)?;
130                Authorizer {
131                    key_source: KeySource::SingleKeySource(Arc::new(KeyData {
132                        kid: None,
133                        algs: vec![Algorithm::EdDSA],
134                        key,
135                    })),
136                    claims_checker,
137                    validation,
138                    jwt_source,
139                }
140            }
141            KeySourceType::EDString(text) => {
142                let key = DecodingKey::from_ed_pem(text.as_bytes())?;
143                Authorizer {
144                    key_source: KeySource::SingleKeySource(Arc::new(KeyData {
145                        kid: None,
146                        algs: vec![Algorithm::EdDSA],
147                        key,
148                    })),
149                    claims_checker,
150                    validation,
151                    jwt_source,
152                }
153            }
154            KeySourceType::Secret(secret) => {
155                let key = DecodingKey::from_secret(secret.as_bytes());
156                Authorizer {
157                    key_source: KeySource::SingleKeySource(Arc::new(KeyData {
158                        kid: None,
159                        algs: vec![Algorithm::HS256, Algorithm::HS384, Algorithm::HS512],
160                        key,
161                    })),
162                    claims_checker,
163                    validation,
164                    jwt_source,
165                }
166            }
167            KeySourceType::JwksPath(path) => {
168                let set: JwkSet = serde_json::from_slice(&read_data(path.as_str())?)?;
169                let keys = set
170                    .keys
171                    .iter()
172                    .map(|k| match KeyData::from_jwk(k) {
173                        Ok(kdata) => Ok(Arc::new(kdata)),
174                        Err(err) => Err(InitError::KeyDecodingError(err)),
175                    })
176                    .collect::<Result<Vec<_>, _>>()?;
177                Authorizer {
178                    key_source: KeySource::MultiKeySource(keys.into()),
179                    claims_checker,
180                    validation,
181                    jwt_source,
182                }
183            }
184            KeySourceType::JwksString(jwks_str) => {
185                // TODO: expose it in JwtAuthorizer or remove
186                let set: JwkSet = serde_json::from_str(jwks_str.as_str())?;
187                let keys = set
188                    .keys
189                    .iter()
190                    .map(|k| match KeyData::from_jwk(k) {
191                        Ok(kdata) => Ok(Arc::new(kdata)),
192                        Err(err) => Err(InitError::KeyDecodingError(err)),
193                    })
194                    .collect::<Result<Vec<_>, _>>()?;
195                Authorizer {
196                    key_source: KeySource::MultiKeySource(keys.into()),
197                    claims_checker,
198                    validation,
199                    jwt_source,
200                }
201            }
202            KeySourceType::Jwks(url) => {
203                let jwks_url = Url::parse(url.as_str()).map_err(|e| InitError::JwksUrlError(e.to_string()))?;
204                let key_store_manager = KeyStoreManager::new(jwks_url, refresh.unwrap_or_default());
205                Authorizer {
206                    key_source: KeySource::KeyStoreSource(key_store_manager),
207                    claims_checker,
208                    validation,
209                    jwt_source,
210                }
211            }
212            KeySourceType::Discovery(issuer_url) => {
213                let jwks_url = Url::parse(&oidc::discover_jwks(issuer_url.as_str(), http_client).await?)
214                    .map_err(|e| InitError::JwksUrlError(e.to_string()))?;
215
216                let key_store_manager = KeyStoreManager::new(jwks_url, refresh.unwrap_or_default());
217                Authorizer {
218                    key_source: KeySource::KeyStoreSource(key_store_manager),
219                    claims_checker,
220                    validation,
221                    jwt_source,
222                }
223            }
224        })
225    }
226
227    pub async fn check_auth(&self, token: &str) -> Result<TokenData<C>, AuthError> {
228        let header = decode_header(token)?;
229        // TODO: (optimisation) build & store jwt_validation in key data, to avoid rebuilding it for each check
230        let val_key = self.key_source.get_key(header).await?;
231        let jwt_validation = &self.validation.to_jwt_validation(&val_key.algs);
232        let token_data = decode::<C>(token, &val_key.key, jwt_validation)?;
233
234        if let Some(ref checker) = self.claims_checker {
235            if !checker(&token_data.claims) {
236                return Err(AuthError::InvalidClaims());
237            }
238        }
239
240        Ok(token_data)
241    }
242
243    pub fn extract_token(&self, h: &HeaderMap) -> Option<String> {
244        match &self.jwt_source {
245            layer::JwtSource::AuthorizationHeader => {
246                let bearer_o: Option<Authorization<Bearer>> = h.typed_get();
247                bearer_o.map(|b| String::from(b.0.token()))
248            }
249            layer::JwtSource::Cookie(name) => h
250                .typed_get::<headers::Cookie>()
251                .and_then(|c| c.get(name.as_str()).map(String::from)),
252        }
253    }
254}
255
256pub trait IntoLayer<C>
257where
258    C: Clone + DeserializeOwned + Send,
259{
260    fn into_layer(self) -> AuthorizationLayer<C>;
261}
262
263impl<C> IntoLayer<C> for Vec<Authorizer<C>>
264where
265    C: Clone + DeserializeOwned + Send,
266{
267    fn into_layer(self) -> AuthorizationLayer<C> {
268        AuthorizationLayer::new(self.into_iter().map(Arc::new).collect())
269    }
270}
271
272impl<C> IntoLayer<C> for Vec<Arc<Authorizer<C>>>
273where
274    C: Clone + DeserializeOwned + Send,
275{
276    fn into_layer(self) -> AuthorizationLayer<C> {
277        AuthorizationLayer::new(self.into_iter().collect())
278    }
279}
280
281impl<C, const N: usize> IntoLayer<C> for [Authorizer<C>; N]
282where
283    C: Clone + DeserializeOwned + Send,
284{
285    fn into_layer(self) -> AuthorizationLayer<C> {
286        AuthorizationLayer::new(self.into_iter().map(Arc::new).collect())
287    }
288}
289
290impl<C, const N: usize> IntoLayer<C> for [Arc<Authorizer<C>>; N]
291where
292    C: Clone + DeserializeOwned + Send,
293{
294    fn into_layer(self) -> AuthorizationLayer<C> {
295        AuthorizationLayer::new(self.into_iter().collect())
296    }
297}
298
299impl<C> IntoLayer<C> for Authorizer<C>
300where
301    C: Clone + DeserializeOwned + Send,
302{
303    fn into_layer(self) -> AuthorizationLayer<C> {
304        AuthorizationLayer::new(vec![Arc::new(self)])
305    }
306}
307
308impl<C> IntoLayer<C> for Arc<Authorizer<C>>
309where
310    C: Clone + DeserializeOwned + Send,
311{
312    fn into_layer(self) -> AuthorizationLayer<C> {
313        AuthorizationLayer::new(vec![self])
314    }
315}
316
#[cfg(test)]
mod tests {
    use jsonwebtoken::{Algorithm, Header};
    use serde_json::Value;

    use crate::{error::InitError, layer::JwtSource, validation::Validation};

    use super::{Authorizer, KeySourceType};

    /// Runs `Authorizer::build` for `source` with the given validation.
    ///
    /// No claims checker, refresh policy or HTTP client is configured, and tokens
    /// are read from the `Authorization` header.
    async fn build_with(
        source: KeySourceType,
        validation: Validation,
    ) -> Result<Authorizer<Value>, InitError> {
        Authorizer::<Value>::build(
            source,
            None,
            None,
            validation,
            JwtSource::AuthorizationHeader,
            None,
        )
        .await
    }

    /// Builds with `Validation::new()` and panics if construction fails.
    async fn build_ok(source: KeySourceType) -> Authorizer<Value> {
        build_with(source, Validation::new()).await.unwrap()
    }

    /// Whether the authorizer's key source yields a key for a header announcing `alg`.
    async fn resolves_key(authorizer: &Authorizer<Value>, alg: Algorithm) -> bool {
        authorizer.key_source.get_key(Header::new(alg)).await.is_ok()
    }

    #[tokio::test]
    async fn build_from_secret() {
        let authorizer = build_ok(KeySourceType::Secret("xxxxxx".to_owned())).await;
        assert!(resolves_key(&authorizer, Algorithm::HS256).await);
    }

    #[tokio::test]
    async fn build_from_jwks_string() {
        let jwks = r#"
                {"keys": [{
                    "kid": "1",
                    "kty": "RSA",
                    "alg": "RS256",
                    "use": "sig",
                    "n": "2pQeZdxa7q093K7bj5h6-leIpxfTnuAxzXdhjfGEJHxmt2ekHyCBWWWXCBiDn2RTcEBcy6gZqOW45Uy_tw-5e-Px1xFj1PykGEkRlOpYSAeWsNaAWvvpGB9m4zQ0PgZeMDDXE5IIBrY6YAzmGQxV-fcGGLhJnXl0-5_z7tKC7RvBoT3SGwlc_AmJqpFtTpEBn_fDnyqiZbpcjXYLExFpExm41xDitRKHWIwfc3dV8_vlNntlxCPGy_THkjdXJoHv2IJmlhvmr5_h03iGMLWDKSywxOol_4Wc1BT7Hb6byMxW40GKwSJJ4p7W8eI5mqggRHc8jlwSsTN9LZ2VOvO-XiVShZRVg7JeraGAfWwaIgIJ1D8C1h5Pi0iFpp2suxpHAXHfyLMJXuVotpXbDh4NDX-A4KRMgaxcfAcui_x6gybksq6gF90-9nfQfmVMVJctZ6M-FvRr-itd1Nef5WAtwUp1qyZygAXU3cH3rarscajmurOsP6dE1OHl3grY_eZhQxk33VBK9lavqNKPg6Q_PLiq1ojbYBj3bcYifJrsNeQwxldQP83aWt5rGtgZTehKVJwa40Uy_Grae1iRnsDtdSy5sTJIJ6EiShnWAdMoGejdiI8vpkjrdU8SWH8lv1KXI54DsbyAuke2cYz02zPWc6JEotQqI0HwhzU0KHyoY4s",
                    "e": "AQAB"
                }]}
        "#;
        let authorizer = build_ok(KeySourceType::JwksString(jwks.to_owned())).await;
        assert!(resolves_key(&authorizer, Algorithm::RS256).await);
    }

    #[tokio::test]
    async fn build_from_file() {
        let single_key_cases = [
            (KeySourceType::RSA("../config/rsa-public1.pem".to_owned()), Algorithm::RS256),
            (KeySourceType::EC("../config/ecdsa-public1.pem".to_owned()), Algorithm::ES256),
            (KeySourceType::ED("../config/ed25519-public1.pem".to_owned()), Algorithm::EdDSA),
        ];
        for (source, alg) in single_key_cases {
            let authorizer = build_ok(source).await;
            assert!(resolves_key(&authorizer, alg).await);
        }

        let authorizer = build_ok(KeySourceType::JwksPath("../config/public1.jwks".to_owned())).await;
        for (alg, message) in [
            (Algorithm::RS256, "Couldn't get RS256 key from jwk"),
            (Algorithm::ES256, "Couldn't get ES256 key from jwk"),
            (Algorithm::EdDSA, "Couldn't get EdDSA key from jwk"),
        ] {
            authorizer.key_source.get_key(Header::new(alg)).await.expect(message);
        }
    }

    #[tokio::test]
    async fn build_from_text() {
        let cases = [
            (
                KeySourceType::RSAString(include_str!("../../config/rsa-public1.pem").to_owned()),
                Algorithm::RS256,
            ),
            (
                KeySourceType::ECString(include_str!("../../config/ecdsa-public1.pem").to_owned()),
                Algorithm::ES256,
            ),
            (
                KeySourceType::EDString(include_str!("../../config/ed25519-public1.pem").to_owned()),
                Algorithm::EdDSA,
            ),
        ];
        for (source, alg) in cases {
            let authorizer = build_ok(source).await;
            assert!(resolves_key(&authorizer, alg).await);
        }
    }

    #[tokio::test]
    async fn build_file_errors() {
        let result = build_with(
            KeySourceType::RSA("./config/does-not-exist.pem".to_owned()),
            Validation::new(),
        )
        .await;
        println!("{:?}", result.as_ref().err());
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn build_jwks_url_error() {
        let result = build_with(KeySourceType::Jwks("://xxxx".to_owned()), Validation::default()).await;
        println!("{:?}", result.as_ref().err());
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn build_discovery_url_error() {
        let result = build_with(
            KeySourceType::Discovery("://xxxx".to_owned()),
            Validation::default(),
        )
        .await;
        println!("{:?}", result.as_ref().err());
        assert!(result.is_err());
    }
}