salvo_jwt_auth/
lib.rs

1//! Provides JWT (JSON Web Token) authentication support for the Salvo web framework.
2//!
3//! This crate helps you implement JWT-based authentication in your Salvo web applications.
4//! It offers flexible token extraction from various sources (headers, query parameters, cookies,
5//! etc.) and multiple decoding strategies.
6//!
7//! # Features
8//!
9//! - Extract JWT tokens from multiple sources (headers, query parameters, cookies, forms)
10//! - Configurable token validation
11//! - OpenID Connect support (behind the `oidc` feature flag)
12//! - Seamless integration with Salvo's middleware system
13//!
//! # Example
15//!
16//! ```no_run
17//! use jsonwebtoken::{self, EncodingKey};
18//! use salvo::http::{Method, StatusError};
19//! use salvo::jwt_auth::{ConstDecoder, QueryFinder};
20//! use salvo::prelude::*;
21//! use serde::{Deserialize, Serialize};
22//! use time::{Duration, OffsetDateTime};
23//!
24//! const SECRET_KEY: &str = "YOUR_SECRET_KEY"; // In production, use a secure key management solution
25//!
26//! #[derive(Serialize, Deserialize, Clone, Debug)]
27//! pub struct JwtClaims {
28//!     username: String,
29//!     exp: i64,
30//! }
31//!
32//! #[tokio::main]
33//! async fn main() {
34//!     let auth_handler: JwtAuth<JwtClaims, _> = JwtAuth::new(ConstDecoder::from_secret(SECRET_KEY.as_bytes()))
35//!         .finders(vec![
36//!             // Box::new(HeaderFinder::new()),
37//!             Box::new(QueryFinder::new("jwt_token")),
38//!             // Box::new(CookieFinder::new("jwt_token")),
39//!         ])
40//!         .force_passed(true);
41//!
42//!     let acceptor = TcpListener::new("0.0.0.0:8698").bind().await;
43//!     Server::new(acceptor)
44//!         .serve(Router::with_hoop(auth_handler).goal(index))
45//!         .await;
46//! }
47//! #[handler]
48//! async fn index(req: &mut Request, depot: &mut Depot, res: &mut Response) -> anyhow::Result<()> {
49//!     if req.method() == Method::POST {
50//!         let (username, password) = (
51//!             req.form::<String>("username").await.unwrap_or_default(),
52//!             req.form::<String>("password").await.unwrap_or_default(),
53//!         );
54//!         if !validate(&username, &password) {
55//!             res.render(Text::Html(LOGIN_HTML));
56//!             return Ok(());
57//!         }
58//!         let exp = OffsetDateTime::now_utc() + Duration::days(14);
59//!         let claim = JwtClaims {
60//!             username,
61//!             exp: exp.unix_timestamp(),
62//!         };
63//!         let token = jsonwebtoken::encode(
64//!             &jsonwebtoken::Header::default(),
65//!             &claim,
66//!             &EncodingKey::from_secret(SECRET_KEY.as_bytes()),
67//!         )?;
68//!         res.render(Redirect::other(format!("/?jwt_token={token}")));
69//!     } else {
70//!         match depot.jwt_auth_state() {
71//!             JwtAuthState::Authorized => {
72//!                 let data = depot.jwt_auth_data::<JwtClaims>().unwrap();
73//!                 res.render(Text::Plain(format!(
74//!                     "Hi {}, you have logged in successfully!",
75//!                     data.claims.username
76//!                 )));
77//!             }
78//!             JwtAuthState::Unauthorized => {
79//!                 res.render(Text::Html(LOGIN_HTML));
80//!             }
81//!             JwtAuthState::Forbidden => {
82//!                 res.render(StatusError::forbidden());
83//!             }
84//!         }
85//!     }
86//!     Ok(())
87//! }
88//!
89//! fn validate(username: &str, password: &str) -> bool {
90//!     // In a real application, use secure password verification
91//!     username == "root" && password == "pwd"
92//! }
93//!
94//! static LOGIN_HTML: &str = r#"<!DOCTYPE html>
95//! <html>
96//!     <head>
97//!         <title>JWT Auth Demo</title>
98//!     </head>
99//!     <body>
100//!         <h1>JWT Auth</h1>
101//!         <form action="/" method="post">
102//!         <label for="username"><b>Username</b></label>
103//!         <input type="text" placeholder="Enter Username" name="username" required>
104//!
105//!         <label for="password"><b>Password</b></label>
106//!         <input type="password" placeholder="Enter Password" name="password" required>
107//!
108//!         <button type="submit">Login</button>
109//!     </form>
110//!     </body>
111//! </html>
112//! "#;
113//! ```
114
115#![doc(html_favicon_url = "https://salvo.rs/favicon-32x32.png")]
116#![doc(html_logo_url = "https://salvo.rs/images/logo.svg")]
117#![cfg_attr(docsrs, feature(doc_cfg))]
118
119use std::fmt::{self, Debug, Formatter};
120use std::marker::PhantomData;
121
122#[doc(no_inline)]
123pub use jsonwebtoken::{
124    Algorithm, DecodingKey, TokenData, Validation, decode, errors::Error as JwtError,
125};
126use salvo_core::http::{Method, Request, Response, StatusError};
127use salvo_core::{Depot, FlowCtrl, Handler, async_trait};
128use serde::de::DeserializeOwned;
129use thiserror::Error;
130
131mod finder;
132pub use finder::{CookieFinder, FormFinder, HeaderFinder, JwtTokenFinder, QueryFinder};
133
134mod decoder;
135pub use decoder::{ConstDecoder, JwtAuthDecoder};
136
137#[macro_use]
138mod cfg;
139
140cfg_feature! {
141    #![feature = "oidc"]
142    pub mod oidc;
143    pub use oidc::OidcDecoder;
144}
145
/// Depot key under which [`JwtAuth`] stores the decoded token data (`TokenData<C>`) after a
/// successful authentication.
pub const JWT_AUTH_DATA_KEY: &str = "::salvo::jwt_auth::auth_data";
/// Depot key under which [`JwtAuth`] stores the [`JwtAuthState`] of the current request.
pub const JWT_AUTH_STATE_KEY: &str = "::salvo::jwt_auth::auth_state";
/// Depot key under which [`JwtAuth`] stores the raw token string after a successful
/// authentication.
pub const JWT_AUTH_TOKEN_KEY: &str = "::salvo::jwt_auth::auth_token";
/// Depot key under which [`JwtAuth`] stores the decoder's error when a token is rejected.
pub const JWT_AUTH_ERROR_KEY: &str = "::salvo::jwt_auth::auth_error";
154
155const ALL_METHODS: [Method; 9] = [
156    Method::GET,
157    Method::POST,
158    Method::PUT,
159    Method::DELETE,
160    Method::HEAD,
161    Method::OPTIONS,
162    Method::CONNECT,
163    Method::PATCH,
164    Method::TRACE,
165];
166
167/// JwtAuthError
168#[derive(Debug, Error)]
169pub enum JwtAuthError {
170    /// HTTP client error
171    #[cfg(feature = "oidc")]
172    #[cfg_attr(docsrs, doc(cfg(feature = "oidc")))]
173    #[error("ClientError")]
174    ClientError(#[from] hyper_util::client::legacy::Error),
175
176    /// Error occurred in hyper.
177    #[cfg(feature = "oidc")]
178    #[cfg_attr(docsrs, doc(cfg(feature = "oidc")))]
179    #[error("HyperError")]
180    Hyper(#[from] salvo_core::hyper::Error),
181
182    /// InvalidUri
183    #[error("InvalidUri")]
184    InvalidUri(#[from] salvo_core::http::uri::InvalidUri),
185    /// Serde error
186    #[error("Serde error")]
187    SerdeError(#[from] serde_json::Error),
188    /// Failed to discover OIDC configuration
189    #[error("Failed to discover OIDC configuration")]
190    DiscoverError,
191    /// Decoding of JWKS error
192    #[error("Decoding of JWKS error")]
193    DecodeError(#[from] base64::DecodeError),
194    /// JWT is missing kid, alg, or decoding components
195    #[error("JWT is missing kid, alg, or decoding components")]
196    InvalidJwk,
197    /// Issuer URL invalid
198    #[error("Issuer URL invalid")]
199    IssuerParseError,
200    /// Failure of validating the token. See [jsonwebtoken::errors::ErrorKind] for possible reasons
201    /// this value could be returned Would typically result in a 401 HTTP Status code
202    #[error("JWT Is Invalid")]
203    ValidationFailed(#[from] jsonwebtoken::errors::Error),
204    /// Failure to re-validate the JWKS.
205    /// Would typically result in a 401 or 500 status code depending on preference
206    #[error("Token was unable to be validated due to cache expiration")]
207    CacheError,
208    /// Token did not contain a kid in its header and would be impossible to validate
209    /// Would typically result in a 401 HTTP Status code
210    #[error("Token did not contain a KID field")]
211    MissingKid,
212}
213
/// Outcome of JWT authentication for the current request.
///
/// The `JwtAuth` middleware stores this value in the depot; read it back with
/// `depot.jwt_auth_state()`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JwtAuthState {
    /// A token was found and the decoder accepted it.
    Authorized,
    /// None of the configured finders produced a token.
    ///
    /// The middleware answers with `401 Unauthorized` unless `force_passed` is `true`.
    Unauthorized,
    /// A token was found, but the decoder rejected it.
    ///
    /// The middleware answers with `403 Forbidden` unless `force_passed` is `true`.
    Forbidden,
}
229
230/// Extension trait for accessing JWT authentication data from the depot.
231///
232/// This trait provides convenient methods to retrieve JWT authentication information
233/// that was previously stored in the depot by the `JwtAuth` middleware.
234pub trait JwtAuthDepotExt {
235    /// Gets the JWT token string from the depot.
236    fn jwt_auth_token(&self) -> Option<&str>;
237
238    /// Gets the decoded JWT claims data from the depot.
239    ///
240    /// The generic parameter `C` should be the same type used when configuring the `JwtAuth`
241    /// middleware.
242    fn jwt_auth_data<C>(&self) -> Option<&TokenData<C>>
243    where
244        C: DeserializeOwned + Send + Sync + 'static;
245
246    /// Gets the current JWT authentication state from the depot.
247    ///
248    /// Returns `JwtAuthState::Unauthorized` if no state is present in the depot.
249    fn jwt_auth_state(&self) -> JwtAuthState;
250
251    /// Gets the JWT error if authentication failed.
252    fn jwt_auth_error(&self) -> Option<&JwtError>;
253}
254
255impl JwtAuthDepotExt for Depot {
256    #[inline]
257    fn jwt_auth_token(&self) -> Option<&str> {
258        self.get::<String>(JWT_AUTH_TOKEN_KEY).map(|v| &**v).ok()
259    }
260
261    #[inline]
262    fn jwt_auth_data<C>(&self) -> Option<&TokenData<C>>
263    where
264        C: DeserializeOwned + Send + Sync + 'static,
265    {
266        self.get(JWT_AUTH_DATA_KEY).ok()
267    }
268
269    #[inline]
270    fn jwt_auth_state(&self) -> JwtAuthState {
271        self.get(JWT_AUTH_STATE_KEY)
272            .ok()
273            .cloned()
274            .unwrap_or(JwtAuthState::Unauthorized)
275    }
276
277    #[inline]
278    fn jwt_auth_error(&self) -> Option<&JwtError> {
279        self.get(JWT_AUTH_ERROR_KEY).ok()
280    }
281}
282
283/// JWT Authentication middleware for Salvo.
284///
285/// `JwtAuth` extracts and validates JWT tokens from incoming requests based on the configured
286/// token finders and decoder. If valid, it stores the decoded data in the depot for later use.
287///
288/// # Type Parameters
289///
290/// * `C` - The claims type that will be deserialized from the JWT payload.
291/// * `D` - The decoder implementation used to validate and decode the JWT token.
292#[non_exhaustive]
293pub struct JwtAuth<C, D> {
294    /// When set to `true`, the middleware will allow the request to proceed even if
295    /// authentication fails, storing only the authentication state in the depot.
296    ///
297    /// When set to `false` (default), requests with invalid or missing tokens will be
298    /// immediately rejected with appropriate status codes.
299    pub force_passed: bool,
300    _claims: PhantomData<C>,
301    /// The decoder used to validate and decode the JWT token.
302    pub decoder: D,
303    /// A list of token finders that will be used to extract the token from the request.
304    /// Finders are tried in order until one returns a token.
305    pub finders: Vec<Box<dyn JwtTokenFinder>>,
306}
307impl<C, D> Debug for JwtAuth<C, D>
308where
309    C: DeserializeOwned + Send + Sync + 'static,
310    D: JwtAuthDecoder + Send + Sync + 'static,
311{
312    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
313        f.debug_struct("JwtAuth")
314            .field("force_passed", &self.force_passed)
315            .finish()
316    }
317}
318
319impl<C, D> JwtAuth<C, D>
320where
321    C: DeserializeOwned + Send + Sync + 'static,
322    D: JwtAuthDecoder + Send + Sync + 'static,
323{
324    /// Create new `JwtAuth`.
325    #[inline]
326    #[must_use]
327    pub fn new(decoder: D) -> Self {
328        Self {
329            force_passed: false,
330            decoder,
331            _claims: PhantomData::<C>,
332            finders: vec![Box::new(HeaderFinder::new())],
333        }
334    }
335    /// Sets force_passed value and return Self.
336    #[inline]
337    #[must_use]
338    pub fn force_passed(mut self, force_passed: bool) -> Self {
339        self.force_passed = force_passed;
340        self
341    }
342
343    /// Get decoder mutable reference.
344    #[inline]
345    pub fn decoder_mut(&mut self) -> &mut D {
346        &mut self.decoder
347    }
348
349    /// Gets a mutable reference to the extractor list.
350    #[inline]
351    pub fn finders_mut(&mut self) -> &mut Vec<Box<dyn JwtTokenFinder>> {
352        &mut self.finders
353    }
354    /// Sets extractor list with new value and return Self.
355    #[inline]
356    #[must_use]
357    pub fn finders(mut self, finders: Vec<Box<dyn JwtTokenFinder>>) -> Self {
358        self.finders = finders;
359        self
360    }
361
362    async fn find_token(&self, req: &mut Request) -> Option<String> {
363        for finder in &self.finders {
364            if let Some(token) = finder.find_token(req).await {
365                return Some(token);
366            }
367        }
368        None
369    }
370}
371
372#[async_trait]
373impl<C, D> Handler for JwtAuth<C, D>
374where
375    C: DeserializeOwned + Clone + Send + Sync + 'static,
376    D: JwtAuthDecoder + Send + Sync + 'static,
377{
378    async fn handle(
379        &self,
380        req: &mut Request,
381        depot: &mut Depot,
382        res: &mut Response,
383        ctrl: &mut FlowCtrl,
384    ) {
385        let token = self.find_token(req).await;
386        if let Some(token) = token {
387            match self.decoder.decode::<C>(&token, depot).await {
388                Ok(data) => {
389                    depot.insert(JWT_AUTH_DATA_KEY, data);
390                    depot.insert(JWT_AUTH_STATE_KEY, JwtAuthState::Authorized);
391                    depot.insert(JWT_AUTH_TOKEN_KEY, token);
392                }
393                Err(e) => {
394                    tracing::info!(error = ?e, "jwt auth error");
395                    depot.insert(JWT_AUTH_STATE_KEY, JwtAuthState::Forbidden);
396                    depot.insert(JWT_AUTH_ERROR_KEY, e);
397                    if !self.force_passed {
398                        res.render(StatusError::forbidden());
399                        ctrl.skip_rest();
400                    }
401                }
402            }
403        } else {
404            depot.insert(JWT_AUTH_STATE_KEY, JwtAuthState::Unauthorized);
405            if !self.force_passed {
406                res.render(StatusError::unauthorized());
407                ctrl.skip_rest();
408            }
409        }
410    }
411}
412
#[cfg(test)]
mod tests {
    use jsonwebtoken::EncodingKey;
    use salvo_core::prelude::*;
    use salvo_core::test::{ResponseExt, TestClient};
    use serde::{Deserialize, Serialize};
    use time::{Duration, OffsetDateTime};

    use super::*;

    #[derive(Serialize, Deserialize, Clone, Debug)]
    struct JwtClaims {
        user: String,
        exp: i64,
    }

    /// Covers every finder on the success path, a token signed with the wrong key (403), and
    /// a request carrying no token at all (401).
    #[tokio::test]
    async fn test_jwt_auth() {
        let auth_handler: JwtAuth<JwtClaims, ConstDecoder> =
            JwtAuth::new(ConstDecoder::from_secret(b"ABCDEF")).finders(vec![
                Box::new(HeaderFinder::new()),
                Box::new(QueryFinder::new("jwt_token")),
                Box::new(CookieFinder::new("jwt_token")),
            ]);

        #[handler]
        async fn hello() -> &'static str {
            "hello"
        }

        let router = Router::new()
            .hoop(auth_handler)
            .push(Router::with_path("hello").get(hello));
        let service = Service::new(router);

        async fn access(service: &Service, token: &str) -> String {
            TestClient::get("http://127.0.0.1:5801/hello")
                .add_header("Authorization", format!("Bearer {token}"), true)
                .send(service)
                .await
                .take_string()
                .await
                .unwrap()
        }

        let claim = JwtClaims {
            user: "root".into(),
            exp: (OffsetDateTime::now_utc() + Duration::days(1)).unix_timestamp(),
        };

        let token = jsonwebtoken::encode(
            &jsonwebtoken::Header::default(),
            &claim,
            &EncodingKey::from_secret(b"ABCDEF"),
        )
        .unwrap();
        let content = access(&service, &token).await;
        assert!(content.contains("hello"));

        let content = TestClient::get(format!("http://127.0.0.1:5801/hello?jwt_token={token}"))
            .send(&service)
            .await
            .take_string()
            .await
            .unwrap();
        assert!(content.contains("hello"));
        let content = TestClient::get("http://127.0.0.1:5801/hello")
            .add_header("Cookie", format!("jwt_token={token}"), true)
            .send(&service)
            .await
            .take_string()
            .await
            .unwrap();
        assert!(content.contains("hello"));

        // Signed with a different secret: the decoder rejects it.
        let token = jsonwebtoken::encode(
            &jsonwebtoken::Header::default(),
            &claim,
            &EncodingKey::from_secret(b"ABCDEFG"),
        )
        .unwrap();
        let content = access(&service, &token).await;
        assert!(content.contains("Forbidden"));

        // No token anywhere: none of the finders succeed.
        let content = TestClient::get("http://127.0.0.1:5801/hello")
            .send(&service)
            .await
            .take_string()
            .await
            .unwrap();
        assert!(content.contains("Unauthorized"));
    }

    /// With `force_passed(true)` a request without a token still reaches the handler, which
    /// can observe the `Unauthorized` state in the depot.
    #[tokio::test]
    async fn test_jwt_auth_force_passed() {
        let auth_handler: JwtAuth<JwtClaims, ConstDecoder> =
            JwtAuth::new(ConstDecoder::from_secret(b"ABCDEF")).force_passed(true);

        #[handler]
        async fn hello(depot: &mut Depot) -> String {
            format!("hello {:?}", depot.jwt_auth_state())
        }

        let router = Router::new()
            .hoop(auth_handler)
            .push(Router::with_path("hello").get(hello));
        let service = Service::new(router);

        let content = TestClient::get("http://127.0.0.1:5801/hello")
            .send(&service)
            .await
            .take_string()
            .await
            .unwrap();
        assert!(content.contains("hello Unauthorized"));
    }
}