1use std::{io::Read, sync::Arc};
2
3use headers::{authorization::Bearer, Authorization, HeaderMapExt};
4use http::HeaderMap;
5use jsonwebtoken::{decode, decode_header, jwk::JwkSet, Algorithm, DecodingKey, TokenData};
6use reqwest::{Client, Url};
7use serde::de::DeserializeOwned;
8
9use crate::{
10 error::{AuthError, InitError},
11 jwks::{key_store_manager::KeyStoreManager, KeyData, KeySource},
12 layer::{self, AuthorizationLayer, JwtSource},
13 oidc, Refresh, RegisteredClaims,
14};
15
/// Shared predicate evaluated against the decoded claims of type `C`.
///
/// The `Send + Sync` bounds let one checker be shared (via the `Arc`) across threads.
/// `Authorizer::check_auth` rejects a token with `AuthError::InvalidClaims` when the
/// predicate returns `false`.
pub type ClaimsCheckerFn<C> = Arc<Box<dyn Fn(&C) -> bool + Sync + Send>>;
17
18pub struct Authorizer<C = RegisteredClaims>
19where
20 C: Clone + Send,
21{
22 pub key_source: KeySource,
23 pub claims_checker: Option<ClaimsCheckerFn<C>>,
24 pub validation: crate::validation::Validation,
25 pub jwt_source: JwtSource,
26}
27
28fn read_data(path: &str) -> Result<Vec<u8>, InitError> {
29 let mut data = Vec::<u8>::new();
30 let mut f = std::fs::File::open(path)?;
31 f.read_to_end(&mut data)?;
32 Ok(data)
33}
34
/// Describes where an [`Authorizer`] obtains its decoding key(s) from.
pub enum KeySourceType {
    /// Path to a PEM-encoded RSA public key file (accepts RS256/384/512 and PS256/384/512).
    RSA(String),
    /// PEM-encoded RSA public key given as text (same algorithms as `RSA`).
    RSAString(String),
    /// Path to a PEM-encoded EC public key file (accepts ES256 and ES384).
    EC(String),
    /// PEM-encoded EC public key given as text (same algorithms as `EC`).
    ECString(String),
    /// Path to a PEM-encoded EdDSA public key file (accepts EdDSA).
    ED(String),
    /// PEM-encoded EdDSA public key given as text (same algorithm as `ED`).
    EDString(String),
    /// Shared HMAC secret (accepts HS256/384/512).
    Secret(String),
    /// URL of a JWKS endpoint; keys are managed by a `KeyStoreManager`.
    Jwks(String),
    /// Path to a local JWKS (JSON) file.
    JwksPath(String),
    /// JWKS (JSON) given as text.
    JwksString(String),
    /// Issuer URL; the JWKS URL is discovered through OIDC discovery.
    Discovery(String),
}
48
49impl<C> Authorizer<C>
50where
51 C: DeserializeOwned + Clone + Send,
52{
53 pub(crate) async fn build(
54 key_source_type: KeySourceType,
55 claims_checker: Option<ClaimsCheckerFn<C>>,
56 refresh: Option<Refresh>,
57 validation: crate::validation::Validation,
58 jwt_source: JwtSource,
59 http_client: Option<Client>,
60 ) -> Result<Authorizer<C>, InitError> {
61 Ok(match key_source_type {
62 KeySourceType::RSA(path) => {
63 let key = DecodingKey::from_rsa_pem(&read_data(path.as_str())?)?;
64 Authorizer {
65 key_source: KeySource::SingleKeySource(Arc::new(KeyData {
66 kid: None,
67 algs: vec![
68 Algorithm::RS256,
69 Algorithm::RS384,
70 Algorithm::RS512,
71 Algorithm::PS256,
72 Algorithm::PS384,
73 Algorithm::PS512,
74 ],
75 key,
76 })),
77 claims_checker,
78 validation,
79 jwt_source,
80 }
81 }
82 KeySourceType::RSAString(text) => {
83 let key = DecodingKey::from_rsa_pem(text.as_bytes())?;
84 Authorizer {
85 key_source: KeySource::SingleKeySource(Arc::new(KeyData {
86 kid: None,
87 algs: vec![
88 Algorithm::RS256,
89 Algorithm::RS384,
90 Algorithm::RS512,
91 Algorithm::PS256,
92 Algorithm::PS384,
93 Algorithm::PS512,
94 ],
95 key,
96 })),
97 claims_checker,
98 validation,
99 jwt_source,
100 }
101 }
102 KeySourceType::EC(path) => {
103 let key = DecodingKey::from_ec_pem(&read_data(path.as_str())?)?;
104 Authorizer {
105 key_source: KeySource::SingleKeySource(Arc::new(KeyData {
106 kid: None,
107 algs: vec![Algorithm::ES256, Algorithm::ES384],
108 key,
109 })),
110 claims_checker,
111 validation,
112 jwt_source,
113 }
114 }
115 KeySourceType::ECString(text) => {
116 let key = DecodingKey::from_ec_pem(text.as_bytes())?;
117 Authorizer {
118 key_source: KeySource::SingleKeySource(Arc::new(KeyData {
119 kid: None,
120 algs: vec![Algorithm::ES256, Algorithm::ES384],
121 key,
122 })),
123 claims_checker,
124 validation,
125 jwt_source,
126 }
127 }
128 KeySourceType::ED(path) => {
129 let key = DecodingKey::from_ed_pem(&read_data(path.as_str())?)?;
130 Authorizer {
131 key_source: KeySource::SingleKeySource(Arc::new(KeyData {
132 kid: None,
133 algs: vec![Algorithm::EdDSA],
134 key,
135 })),
136 claims_checker,
137 validation,
138 jwt_source,
139 }
140 }
141 KeySourceType::EDString(text) => {
142 let key = DecodingKey::from_ed_pem(text.as_bytes())?;
143 Authorizer {
144 key_source: KeySource::SingleKeySource(Arc::new(KeyData {
145 kid: None,
146 algs: vec![Algorithm::EdDSA],
147 key,
148 })),
149 claims_checker,
150 validation,
151 jwt_source,
152 }
153 }
154 KeySourceType::Secret(secret) => {
155 let key = DecodingKey::from_secret(secret.as_bytes());
156 Authorizer {
157 key_source: KeySource::SingleKeySource(Arc::new(KeyData {
158 kid: None,
159 algs: vec![Algorithm::HS256, Algorithm::HS384, Algorithm::HS512],
160 key,
161 })),
162 claims_checker,
163 validation,
164 jwt_source,
165 }
166 }
167 KeySourceType::JwksPath(path) => {
168 let set: JwkSet = serde_json::from_slice(&read_data(path.as_str())?)?;
169 let keys = set
170 .keys
171 .iter()
172 .map(|k| match KeyData::from_jwk(k) {
173 Ok(kdata) => Ok(Arc::new(kdata)),
174 Err(err) => Err(InitError::KeyDecodingError(err)),
175 })
176 .collect::<Result<Vec<_>, _>>()?;
177 Authorizer {
178 key_source: KeySource::MultiKeySource(keys.into()),
179 claims_checker,
180 validation,
181 jwt_source,
182 }
183 }
184 KeySourceType::JwksString(jwks_str) => {
185 let set: JwkSet = serde_json::from_str(jwks_str.as_str())?;
187 let keys = set
188 .keys
189 .iter()
190 .map(|k| match KeyData::from_jwk(k) {
191 Ok(kdata) => Ok(Arc::new(kdata)),
192 Err(err) => Err(InitError::KeyDecodingError(err)),
193 })
194 .collect::<Result<Vec<_>, _>>()?;
195 Authorizer {
196 key_source: KeySource::MultiKeySource(keys.into()),
197 claims_checker,
198 validation,
199 jwt_source,
200 }
201 }
202 KeySourceType::Jwks(url) => {
203 let jwks_url = Url::parse(url.as_str()).map_err(|e| InitError::JwksUrlError(e.to_string()))?;
204 let key_store_manager = KeyStoreManager::new(jwks_url, refresh.unwrap_or_default());
205 Authorizer {
206 key_source: KeySource::KeyStoreSource(key_store_manager),
207 claims_checker,
208 validation,
209 jwt_source,
210 }
211 }
212 KeySourceType::Discovery(issuer_url) => {
213 let jwks_url = Url::parse(&oidc::discover_jwks(issuer_url.as_str(), http_client).await?)
214 .map_err(|e| InitError::JwksUrlError(e.to_string()))?;
215
216 let key_store_manager = KeyStoreManager::new(jwks_url, refresh.unwrap_or_default());
217 Authorizer {
218 key_source: KeySource::KeyStoreSource(key_store_manager),
219 claims_checker,
220 validation,
221 jwt_source,
222 }
223 }
224 })
225 }
226
227 pub async fn check_auth(&self, token: &str) -> Result<TokenData<C>, AuthError> {
228 let header = decode_header(token)?;
229 let val_key = self.key_source.get_key(header).await?;
231 let jwt_validation = &self.validation.to_jwt_validation(&val_key.algs);
232 let token_data = decode::<C>(token, &val_key.key, jwt_validation)?;
233
234 if let Some(ref checker) = self.claims_checker {
235 if !checker(&token_data.claims) {
236 return Err(AuthError::InvalidClaims());
237 }
238 }
239
240 Ok(token_data)
241 }
242
243 pub fn extract_token(&self, h: &HeaderMap) -> Option<String> {
244 match &self.jwt_source {
245 layer::JwtSource::AuthorizationHeader => {
246 let bearer_o: Option<Authorization<Bearer>> = h.typed_get();
247 bearer_o.map(|b| String::from(b.0.token()))
248 }
249 layer::JwtSource::Cookie(name) => h
250 .typed_get::<headers::Cookie>()
251 .and_then(|c| c.get(name.as_str()).map(String::from)),
252 }
253 }
254}
255
256pub trait IntoLayer<C>
257where
258 C: Clone + DeserializeOwned + Send,
259{
260 fn into_layer(self) -> AuthorizationLayer<C>;
261}
262
263impl<C> IntoLayer<C> for Vec<Authorizer<C>>
264where
265 C: Clone + DeserializeOwned + Send,
266{
267 fn into_layer(self) -> AuthorizationLayer<C> {
268 AuthorizationLayer::new(self.into_iter().map(Arc::new).collect())
269 }
270}
271
272impl<C> IntoLayer<C> for Vec<Arc<Authorizer<C>>>
273where
274 C: Clone + DeserializeOwned + Send,
275{
276 fn into_layer(self) -> AuthorizationLayer<C> {
277 AuthorizationLayer::new(self.into_iter().collect())
278 }
279}
280
281impl<C, const N: usize> IntoLayer<C> for [Authorizer<C>; N]
282where
283 C: Clone + DeserializeOwned + Send,
284{
285 fn into_layer(self) -> AuthorizationLayer<C> {
286 AuthorizationLayer::new(self.into_iter().map(Arc::new).collect())
287 }
288}
289
290impl<C, const N: usize> IntoLayer<C> for [Arc<Authorizer<C>>; N]
291where
292 C: Clone + DeserializeOwned + Send,
293{
294 fn into_layer(self) -> AuthorizationLayer<C> {
295 AuthorizationLayer::new(self.into_iter().collect())
296 }
297}
298
299impl<C> IntoLayer<C> for Authorizer<C>
300where
301 C: Clone + DeserializeOwned + Send,
302{
303 fn into_layer(self) -> AuthorizationLayer<C> {
304 AuthorizationLayer::new(vec![Arc::new(self)])
305 }
306}
307
308impl<C> IntoLayer<C> for Arc<Authorizer<C>>
309where
310 C: Clone + DeserializeOwned + Send,
311{
312 fn into_layer(self) -> AuthorizationLayer<C> {
313 AuthorizationLayer::new(vec![self])
314 }
315}
316
#[cfg(test)]
mod tests {
    use jsonwebtoken::{Algorithm, Header};
    use serde_json::Value;

    use crate::{error::InitError, layer::JwtSource, validation::Validation};

    use super::{Authorizer, KeySourceType};

    /// Builds an authorizer over `serde_json::Value` claims with no claims checker,
    /// no refresh setting, bearer-header token extraction and no custom HTTP client.
    async fn try_build(
        source: KeySourceType,
        validation: Validation,
    ) -> Result<Authorizer<Value>, InitError> {
        Authorizer::<Value>::build(
            source,
            None,
            None,
            validation,
            JwtSource::AuthorizationHeader,
            None,
        )
        .await
    }

    /// Like `try_build` with `Validation::new()`, panicking if construction fails.
    async fn build_ok(source: KeySourceType) -> Authorizer<Value> {
        try_build(source, Validation::new()).await.unwrap()
    }

    /// Asserts that `authorizer` resolves a key for a header using `alg`.
    async fn assert_has_key(authorizer: &Authorizer<Value>, alg: Algorithm) {
        assert!(authorizer.key_source.get_key(Header::new(alg)).await.is_ok());
    }

    /// Asserts that construction failed, printing the error for diagnostics.
    fn assert_build_failed(result: &Result<Authorizer<Value>, InitError>) {
        println!("{:?}", result.as_ref().err());
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn build_from_secret() {
        let a = build_ok(KeySourceType::Secret("xxxxxx".to_owned())).await;
        assert_has_key(&a, Algorithm::HS256).await;
    }

    #[tokio::test]
    async fn build_from_jwks_string() {
        let jwks = r#"
            {"keys": [{
                "kid": "1",
                "kty": "RSA",
                "alg": "RS256",
                "use": "sig",
                "n": "2pQeZdxa7q093K7bj5h6-leIpxfTnuAxzXdhjfGEJHxmt2ekHyCBWWWXCBiDn2RTcEBcy6gZqOW45Uy_tw-5e-Px1xFj1PykGEkRlOpYSAeWsNaAWvvpGB9m4zQ0PgZeMDDXE5IIBrY6YAzmGQxV-fcGGLhJnXl0-5_z7tKC7RvBoT3SGwlc_AmJqpFtTpEBn_fDnyqiZbpcjXYLExFpExm41xDitRKHWIwfc3dV8_vlNntlxCPGy_THkjdXJoHv2IJmlhvmr5_h03iGMLWDKSywxOol_4Wc1BT7Hb6byMxW40GKwSJJ4p7W8eI5mqggRHc8jlwSsTN9LZ2VOvO-XiVShZRVg7JeraGAfWwaIgIJ1D8C1h5Pi0iFpp2suxpHAXHfyLMJXuVotpXbDh4NDX-A4KRMgaxcfAcui_x6gybksq6gF90-9nfQfmVMVJctZ6M-FvRr-itd1Nef5WAtwUp1qyZygAXU3cH3rarscajmurOsP6dE1OHl3grY_eZhQxk33VBK9lavqNKPg6Q_PLiq1ojbYBj3bcYifJrsNeQwxldQP83aWt5rGtgZTehKVJwa40Uy_Grae1iRnsDtdSy5sTJIJ6EiShnWAdMoGejdiI8vpkjrdU8SWH8lv1KXI54DsbyAuke2cYz02zPWc6JEotQqI0HwhzU0KHyoY4s",
                "e": "AQAB"
            }]}
        "#;
        let a = build_ok(KeySourceType::JwksString(jwks.to_owned())).await;
        assert_has_key(&a, Algorithm::RS256).await;
    }

    #[tokio::test]
    async fn build_from_file() {
        let a = build_ok(KeySourceType::RSA("../config/rsa-public1.pem".to_owned())).await;
        assert_has_key(&a, Algorithm::RS256).await;

        let a = build_ok(KeySourceType::EC("../config/ecdsa-public1.pem".to_owned())).await;
        assert_has_key(&a, Algorithm::ES256).await;

        let a = build_ok(KeySourceType::ED("../config/ed25519-public1.pem".to_owned())).await;
        assert_has_key(&a, Algorithm::EdDSA).await;

        let a = build_ok(KeySourceType::JwksPath("../config/public1.jwks".to_owned())).await;
        for (alg, failure) in [
            (Algorithm::RS256, "Couldn't get RS256 key from jwk"),
            (Algorithm::ES256, "Couldn't get ES256 key from jwk"),
            (Algorithm::EdDSA, "Couldn't get EdDSA key from jwk"),
        ] {
            a.key_source.get_key(Header::new(alg)).await.expect(failure);
        }
    }

    #[tokio::test]
    async fn build_from_text() {
        let a = build_ok(KeySourceType::RSAString(
            include_str!("../../config/rsa-public1.pem").to_owned(),
        ))
        .await;
        assert_has_key(&a, Algorithm::RS256).await;

        let a = build_ok(KeySourceType::ECString(
            include_str!("../../config/ecdsa-public1.pem").to_owned(),
        ))
        .await;
        assert_has_key(&a, Algorithm::ES256).await;

        let a = build_ok(KeySourceType::EDString(
            include_str!("../../config/ed25519-public1.pem").to_owned(),
        ))
        .await;
        assert_has_key(&a, Algorithm::EdDSA).await;
    }

    #[tokio::test]
    async fn build_file_errors() {
        let a = try_build(
            KeySourceType::RSA("./config/does-not-exist.pem".to_owned()),
            Validation::new(),
        )
        .await;
        assert_build_failed(&a);
    }

    #[tokio::test]
    async fn build_jwks_url_error() {
        let a = try_build(KeySourceType::Jwks("://xxxx".to_owned()), Validation::default()).await;
        assert_build_failed(&a);
    }

    #[tokio::test]
    async fn build_discovery_url_error() {
        let a = try_build(
            KeySourceType::Discovery("://xxxx".to_owned()),
            Validation::default(),
        )
        .await;
        assert_build_failed(&a);
    }
}