1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146
use async_trait::async_trait; pub use jsonwebtoken::errors::Error as JwtError; pub use jsonwebtoken::{decode, Algorithm, DecodingKey, TokenData, Validation}; use serde::de::DeserializeOwned; use std::marker::PhantomData; use salvo_core::Depot; use salvo_core::http::header::AUTHORIZATION; use salvo_core::http::{Request, Response}; use salvo_core::http::errors::*; use salvo_core::Handler; pub struct JwtHandler<C> where C: DeserializeOwned + Sync + Send + 'static, { config: JwtConfig<C>, } pub struct JwtConfig<C> where C: DeserializeOwned + Sync + Send + 'static, { pub secret: String, pub context_token_key: Option<String>, pub context_data_key: Option<String>, pub context_state_key: Option<String>, pub response_error: bool, pub claims: PhantomData<C>, pub validation: Validation, pub extractors: Vec<Box<dyn JwtExtractor>>, } #[async_trait] pub trait JwtExtractor: Send + Sync { async fn get_token(&self, req: &mut Request) -> Option<String>; } #[derive(Default)] pub struct HeaderExtractor; impl HeaderExtractor { pub fn new() -> Self { HeaderExtractor {} } } #[async_trait] impl JwtExtractor for HeaderExtractor { async fn get_token(&self, req: &mut Request) -> Option<String> { if let Some(auth) = req.headers().get(AUTHORIZATION) { if let Ok(auth) = auth.to_str() { if auth.starts_with("Bearer") { return auth.splitn(2, ' ').collect::<Vec<&str>>().pop().map(|s| s.to_owned()); } } } None } } pub struct FormExtractor(String); impl FormExtractor { pub fn new<T: Into<String>>(name: T) -> Self { FormExtractor(name.into()) } } #[async_trait] impl JwtExtractor for FormExtractor { async fn get_token(&self, req: &mut Request) -> Option<String> { req.get_form(&self.0).await } } pub struct QueryExtractor(String); impl QueryExtractor { pub fn new<T: Into<String>>(name: T) -> Self { QueryExtractor(name.into()) } } #[async_trait] impl JwtExtractor for QueryExtractor { async fn get_token(&self, req: &mut Request) -> Option<String> { req.get_query(&self.0) } } pub struct CookieExtractor(String); impl CookieExtractor { pub fn new<T: Into<String>>(name: T) -> Self { CookieExtractor(name.into()) } } #[async_trait] impl JwtExtractor for CookieExtractor { async fn get_token(&self, req: &mut Request) -> Option<String> { req.get_cookie(&self.0).map(|c| c.value().to_owned()) } } impl<C> JwtHandler<C> where C: DeserializeOwned + Sync + Send + 'static, { pub fn new(config: JwtConfig<C>) -> JwtHandler<C> { JwtHandler { config } } pub fn decode(&self, token: &str) -> Result<TokenData<C>, JwtError> { decode::<C>(&token, &DecodingKey::from_secret(&*self.config.secret.as_ref()), &self.config.validation) } } #[async_trait] impl<C> Handler for JwtHandler<C> where C: DeserializeOwned + Sync + Send + 'static, { async fn handle(&self, req: &mut Request, depot: &mut Depot, res: &mut Response) { for extractor in &self.config.extractors { if let Some(token) = extractor.get_token(req).await { if let Ok(data) = self.decode(&token) { if let Some(key) = &self.config.context_data_key { depot.insert(key.clone(), data); } if let Some(key) = &self.config.context_state_key { depot.insert(key.clone(), "authorized"); } } else { if let Some(key) = &self.config.context_state_key { depot.insert(key.clone(), "forbidden"); } if self.config.response_error { res.set_http_error(Forbidden()); } } if let Some(key) = &self.config.context_token_key { depot.insert(key.clone(), token); } return; } } if let Some(key) = &self.config.context_state_key { depot.insert(key.clone(), "unauthorized"); } if self.config.response_error { res.set_http_error(Unauthorized()); } } }