salvo_jwt_auth/lib.rs
1//! Provides JWT (JSON Web Token) authentication support for the Salvo web framework.
2//!
3//! This crate helps you implement JWT-based authentication in your Salvo web applications.
4//! It offers flexible token extraction from various sources (headers, query parameters, cookies,
5//! etc.) and multiple decoding strategies.
6//!
7//! # Features
8//!
9//! - Extract JWT tokens from multiple sources (headers, query parameters, cookies, forms)
10//! - Configurable token validation
11//! - OpenID Connect support (behind the `oidc` feature flag)
12//! - Seamless integration with Salvo's middleware system
13//!
//! # Example
15//!
16//! ```no_run
17//! use jsonwebtoken::{self, EncodingKey};
18//! use salvo::http::{Method, StatusError};
19//! use salvo::jwt_auth::{ConstDecoder, QueryFinder};
20//! use salvo::prelude::*;
21//! use serde::{Deserialize, Serialize};
22//! use time::{Duration, OffsetDateTime};
23//!
24//! const SECRET_KEY: &str = "YOUR_SECRET_KEY"; // In production, use a secure key management solution
25//!
26//! #[derive(Serialize, Deserialize, Clone, Debug)]
27//! pub struct JwtClaims {
28//! username: String,
29//! exp: i64,
30//! }
31//!
32//! #[tokio::main]
33//! async fn main() {
34//! let auth_handler: JwtAuth<JwtClaims, _> = JwtAuth::new(ConstDecoder::from_secret(SECRET_KEY.as_bytes()))
35//! .finders(vec![
36//! // Box::new(HeaderFinder::new()),
37//! Box::new(QueryFinder::new("jwt_token")),
38//! // Box::new(CookieFinder::new("jwt_token")),
39//! ])
40//! .force_passed(true);
41//!
42//! let acceptor = TcpListener::new("0.0.0.0:8698").bind().await;
43//! Server::new(acceptor)
44//! .serve(Router::with_hoop(auth_handler).goal(index))
45//! .await;
46//! }
47//! #[handler]
48//! async fn index(req: &mut Request, depot: &mut Depot, res: &mut Response) -> anyhow::Result<()> {
49//! if req.method() == Method::POST {
50//! let (username, password) = (
51//! req.form::<String>("username").await.unwrap_or_default(),
52//! req.form::<String>("password").await.unwrap_or_default(),
53//! );
54//! if !validate(&username, &password) {
55//! res.render(Text::Html(LOGIN_HTML));
56//! return Ok(());
57//! }
58//! let exp = OffsetDateTime::now_utc() + Duration::days(14);
59//! let claim = JwtClaims {
60//! username,
61//! exp: exp.unix_timestamp(),
62//! };
63//! let token = jsonwebtoken::encode(
64//! &jsonwebtoken::Header::default(),
65//! &claim,
66//! &EncodingKey::from_secret(SECRET_KEY.as_bytes()),
67//! )?;
68//! res.render(Redirect::other(format!("/?jwt_token={token}")));
69//! } else {
70//! match depot.jwt_auth_state() {
71//! JwtAuthState::Authorized => {
72//! let data = depot.jwt_auth_data::<JwtClaims>().unwrap();
73//! res.render(Text::Plain(format!(
74//! "Hi {}, you have logged in successfully!",
75//! data.claims.username
76//! )));
77//! }
78//! JwtAuthState::Unauthorized => {
79//! res.render(Text::Html(LOGIN_HTML));
80//! }
81//! JwtAuthState::Forbidden => {
82//! res.render(StatusError::forbidden());
83//! }
84//! }
85//! }
86//! Ok(())
87//! }
88//!
89//! fn validate(username: &str, password: &str) -> bool {
90//! // In a real application, use secure password verification
91//! username == "root" && password == "pwd"
92//! }
93//!
94//! static LOGIN_HTML: &str = r#"<!DOCTYPE html>
95//! <html>
96//! <head>
97//! <title>JWT Auth Demo</title>
98//! </head>
99//! <body>
100//! <h1>JWT Auth</h1>
101//! <form action="/" method="post">
102//! <label for="username"><b>Username</b></label>
103//! <input type="text" placeholder="Enter Username" name="username" required>
104//!
105//! <label for="password"><b>Password</b></label>
106//! <input type="password" placeholder="Enter Password" name="password" required>
107//!
108//! <button type="submit">Login</button>
109//! </form>
110//! </body>
111//! </html>
112//! "#;
113//! ```
114
115#![doc(html_favicon_url = "https://salvo.rs/favicon-32x32.png")]
116#![doc(html_logo_url = "https://salvo.rs/images/logo.svg")]
117#![cfg_attr(docsrs, feature(doc_cfg))]
118
119use std::fmt::{self, Debug, Formatter};
120use std::marker::PhantomData;
121
122#[doc(no_inline)]
123pub use jsonwebtoken::{
124 Algorithm, DecodingKey, TokenData, Validation, decode, errors::Error as JwtError,
125};
126use salvo_core::http::{Method, Request, Response, StatusError};
127use salvo_core::{Depot, FlowCtrl, Handler, async_trait};
128use serde::de::DeserializeOwned;
129use thiserror::Error;
130
131mod finder;
132pub use finder::{CookieFinder, FormFinder, HeaderFinder, JwtTokenFinder, QueryFinder};
133
134mod decoder;
135pub use decoder::{ConstDecoder, JwtAuthDecoder};
136
137#[macro_use]
138mod cfg;
139
140cfg_feature! {
141 #![feature = "oidc"]
142 pub mod oidc;
143 pub use oidc::OidcDecoder;
144}
145
// Depot keys written by the `JwtAuth` middleware. All of them share the
// `::salvo::jwt_auth::` prefix.

/// Depot key holding the raw token string; written only after the decoder accepts the token.
pub const JWT_AUTH_TOKEN_KEY: &str = "::salvo::jwt_auth::auth_token";
/// Depot key holding the decoded [`TokenData`]; written only after the decoder accepts the token.
pub const JWT_AUTH_DATA_KEY: &str = "::salvo::jwt_auth::auth_data";
/// Depot key holding the [`JwtAuthState`] recorded for every request the middleware handles.
pub const JWT_AUTH_STATE_KEY: &str = "::salvo::jwt_auth::auth_state";
/// Depot key holding the decoder's error; written only when a found token is rejected.
pub const JWT_AUTH_ERROR_KEY: &str = "::salvo::jwt_auth::auth_error";
154
/// The nine HTTP methods listed below, in a fixed order.
///
/// NOTE(review): nothing in this file reads this constant; it is presumably used by the
/// `finder` module (e.g. as the default set of request methods a token finder applies to).
/// Confirm there before changing its contents or order.
const ALL_METHODS: [Method; 9] = [
    Method::GET,
    Method::POST,
    Method::PUT,
    Method::DELETE,
    Method::HEAD,
    Method::OPTIONS,
    Method::CONNECT,
    Method::PATCH,
    Method::TRACE,
];
166
167/// JwtAuthError
168#[derive(Debug, Error)]
169pub enum JwtAuthError {
170 /// HTTP client error
171 #[cfg(feature = "oidc")]
172 #[cfg_attr(docsrs, doc(cfg(feature = "oidc")))]
173 #[error("ClientError")]
174 ClientError(#[from] hyper_util::client::legacy::Error),
175
176 /// Error occurred in hyper.
177 #[cfg(feature = "oidc")]
178 #[cfg_attr(docsrs, doc(cfg(feature = "oidc")))]
179 #[error("HyperError")]
180 Hyper(#[from] salvo_core::hyper::Error),
181
182 /// InvalidUri
183 #[error("InvalidUri")]
184 InvalidUri(#[from] salvo_core::http::uri::InvalidUri),
185 /// Serde error
186 #[error("Serde error")]
187 SerdeError(#[from] serde_json::Error),
188 /// Failed to discover OIDC configuration
189 #[error("Failed to discover OIDC configuration")]
190 DiscoverError,
191 /// Decoding of JWKS error
192 #[error("Decoding of JWKS error")]
193 DecodeError(#[from] base64::DecodeError),
194 /// JWT is missing kid, alg, or decoding components
195 #[error("JWT is missing kid, alg, or decoding components")]
196 InvalidJwk,
197 /// Issuer URL invalid
198 #[error("Issuer URL invalid")]
199 IssuerParseError,
200 /// Failure of validating the token. See [jsonwebtoken::errors::ErrorKind] for possible reasons
201 /// this value could be returned Would typically result in a 401 HTTP Status code
202 #[error("JWT Is Invalid")]
203 ValidationFailed(#[from] jsonwebtoken::errors::Error),
204 /// Failure to re-validate the JWKS.
205 /// Would typically result in a 401 or 500 status code depending on preference
206 #[error("Token was unable to be validated due to cache expiration")]
207 CacheError,
208 /// Token did not contain a kid in its header and would be impossible to validate
209 /// Would typically result in a 401 HTTP Status code
210 #[error("Token did not contain a KID field")]
211 MissingKid,
212}
213
/// Outcome of JWT authentication for the current request.
///
/// [`JwtAuth`] records one of these in the depot under [`JWT_AUTH_STATE_KEY`] for every
/// request it processes; read it back with `depot.jwt_auth_state()`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JwtAuthState {
    /// A token was found and the decoder accepted it.
    Authorized,
    /// None of the configured finders produced a token.
    ///
    /// Unless `force_passed` is enabled, the middleware answers with 401 Unauthorized.
    /// `jwt_auth_state()` also reports this value when no state was recorded at all.
    Unauthorized,
    /// A token was found, but the decoder rejected it.
    ///
    /// Unless `force_passed` is enabled, the middleware answers with 403 Forbidden.
    Forbidden,
}
229
230/// Extension trait for accessing JWT authentication data from the depot.
231///
232/// This trait provides convenient methods to retrieve JWT authentication information
233/// that was previously stored in the depot by the `JwtAuth` middleware.
234pub trait JwtAuthDepotExt {
235 /// Gets the JWT token string from the depot.
236 fn jwt_auth_token(&self) -> Option<&str>;
237
238 /// Gets the decoded JWT claims data from the depot.
239 ///
240 /// The generic parameter `C` should be the same type used when configuring the `JwtAuth`
241 /// middleware.
242 fn jwt_auth_data<C>(&self) -> Option<&TokenData<C>>
243 where
244 C: DeserializeOwned + Send + Sync + 'static;
245
246 /// Gets the current JWT authentication state from the depot.
247 ///
248 /// Returns `JwtAuthState::Unauthorized` if no state is present in the depot.
249 fn jwt_auth_state(&self) -> JwtAuthState;
250
251 /// Gets the JWT error if authentication failed.
252 fn jwt_auth_error(&self) -> Option<&JwtError>;
253}
254
255impl JwtAuthDepotExt for Depot {
256 #[inline]
257 fn jwt_auth_token(&self) -> Option<&str> {
258 self.get::<String>(JWT_AUTH_TOKEN_KEY).map(|v| &**v).ok()
259 }
260
261 #[inline]
262 fn jwt_auth_data<C>(&self) -> Option<&TokenData<C>>
263 where
264 C: DeserializeOwned + Send + Sync + 'static,
265 {
266 self.get(JWT_AUTH_DATA_KEY).ok()
267 }
268
269 #[inline]
270 fn jwt_auth_state(&self) -> JwtAuthState {
271 self.get(JWT_AUTH_STATE_KEY)
272 .ok()
273 .cloned()
274 .unwrap_or(JwtAuthState::Unauthorized)
275 }
276
277 #[inline]
278 fn jwt_auth_error(&self) -> Option<&JwtError> {
279 self.get(JWT_AUTH_ERROR_KEY).ok()
280 }
281}
282
283/// JWT Authentication middleware for Salvo.
284///
285/// `JwtAuth` extracts and validates JWT tokens from incoming requests based on the configured
286/// token finders and decoder. If valid, it stores the decoded data in the depot for later use.
287///
288/// # Type Parameters
289///
290/// * `C` - The claims type that will be deserialized from the JWT payload.
291/// * `D` - The decoder implementation used to validate and decode the JWT token.
292#[non_exhaustive]
293pub struct JwtAuth<C, D> {
294 /// When set to `true`, the middleware will allow the request to proceed even if
295 /// authentication fails, storing only the authentication state in the depot.
296 ///
297 /// When set to `false` (default), requests with invalid or missing tokens will be
298 /// immediately rejected with appropriate status codes.
299 pub force_passed: bool,
300 _claims: PhantomData<C>,
301 /// The decoder used to validate and decode the JWT token.
302 pub decoder: D,
303 /// A list of token finders that will be used to extract the token from the request.
304 /// Finders are tried in order until one returns a token.
305 pub finders: Vec<Box<dyn JwtTokenFinder>>,
306}
307impl<C, D> Debug for JwtAuth<C, D>
308where
309 C: DeserializeOwned + Send + Sync + 'static,
310 D: JwtAuthDecoder + Send + Sync + 'static,
311{
312 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
313 f.debug_struct("JwtAuth")
314 .field("force_passed", &self.force_passed)
315 .finish()
316 }
317}
318
319impl<C, D> JwtAuth<C, D>
320where
321 C: DeserializeOwned + Send + Sync + 'static,
322 D: JwtAuthDecoder + Send + Sync + 'static,
323{
324 /// Create new `JwtAuth`.
325 #[inline]
326 #[must_use]
327 pub fn new(decoder: D) -> Self {
328 Self {
329 force_passed: false,
330 decoder,
331 _claims: PhantomData::<C>,
332 finders: vec![Box::new(HeaderFinder::new())],
333 }
334 }
335 /// Sets force_passed value and return Self.
336 #[inline]
337 #[must_use]
338 pub fn force_passed(mut self, force_passed: bool) -> Self {
339 self.force_passed = force_passed;
340 self
341 }
342
343 /// Get decoder mutable reference.
344 #[inline]
345 pub fn decoder_mut(&mut self) -> &mut D {
346 &mut self.decoder
347 }
348
349 /// Gets a mutable reference to the extractor list.
350 #[inline]
351 pub fn finders_mut(&mut self) -> &mut Vec<Box<dyn JwtTokenFinder>> {
352 &mut self.finders
353 }
354 /// Sets extractor list with new value and return Self.
355 #[inline]
356 #[must_use]
357 pub fn finders(mut self, finders: Vec<Box<dyn JwtTokenFinder>>) -> Self {
358 self.finders = finders;
359 self
360 }
361
362 async fn find_token(&self, req: &mut Request) -> Option<String> {
363 for finder in &self.finders {
364 if let Some(token) = finder.find_token(req).await {
365 return Some(token);
366 }
367 }
368 None
369 }
370}
371
372#[async_trait]
373impl<C, D> Handler for JwtAuth<C, D>
374where
375 C: DeserializeOwned + Clone + Send + Sync + 'static,
376 D: JwtAuthDecoder + Send + Sync + 'static,
377{
378 async fn handle(
379 &self,
380 req: &mut Request,
381 depot: &mut Depot,
382 res: &mut Response,
383 ctrl: &mut FlowCtrl,
384 ) {
385 let token = self.find_token(req).await;
386 if let Some(token) = token {
387 match self.decoder.decode::<C>(&token, depot).await {
388 Ok(data) => {
389 depot.insert(JWT_AUTH_DATA_KEY, data);
390 depot.insert(JWT_AUTH_STATE_KEY, JwtAuthState::Authorized);
391 depot.insert(JWT_AUTH_TOKEN_KEY, token);
392 }
393 Err(e) => {
394 tracing::info!(error = ?e, "jwt auth error");
395 depot.insert(JWT_AUTH_STATE_KEY, JwtAuthState::Forbidden);
396 depot.insert(JWT_AUTH_ERROR_KEY, e);
397 if !self.force_passed {
398 res.render(StatusError::forbidden());
399 ctrl.skip_rest();
400 }
401 }
402 }
403 } else {
404 depot.insert(JWT_AUTH_STATE_KEY, JwtAuthState::Unauthorized);
405 if !self.force_passed {
406 res.render(StatusError::unauthorized());
407 ctrl.skip_rest();
408 }
409 }
410 }
411}
412
#[cfg(test)]
mod tests {
    use jsonwebtoken::EncodingKey;
    use salvo_core::prelude::*;
    use salvo_core::test::{ResponseExt, TestClient};
    use serde::{Deserialize, Serialize};
    use time::{Duration, OffsetDateTime};

    use super::*;

    #[derive(Serialize, Deserialize, Clone, Debug)]
    struct JwtClaims {
        user: String,
        exp: i64,
    }

    /// Encodes a claim for user `root`, expiring in one day, signed with `secret`
    /// using the default `jsonwebtoken` header.
    fn make_token(secret: &[u8]) -> String {
        let claim = JwtClaims {
            user: "root".into(),
            exp: (OffsetDateTime::now_utc() + Duration::days(1)).unix_timestamp(),
        };
        jsonwebtoken::encode(
            &jsonwebtoken::Header::default(),
            &claim,
            &EncodingKey::from_secret(secret),
        )
        .unwrap()
    }

    #[tokio::test]
    async fn test_jwt_auth() {
        let auth_handler: JwtAuth<JwtClaims, ConstDecoder> =
            JwtAuth::new(ConstDecoder::from_secret(b"ABCDEF")).finders(vec![
                Box::new(HeaderFinder::new()),
                Box::new(QueryFinder::new("jwt_token")),
                Box::new(CookieFinder::new("jwt_token")),
            ]);

        #[handler]
        async fn hello() -> &'static str {
            "hello"
        }

        let router = Router::new()
            .hoop(auth_handler)
            .push(Router::with_path("hello").get(hello));
        let service = Service::new(router);

        async fn access(service: &Service, token: &str) -> String {
            TestClient::get("http://127.0.0.1:5801/hello")
                .add_header("Authorization", format!("Bearer {token}"), true)
                .send(service)
                .await
                .take_string()
                .await
                .unwrap()
        }

        let token = make_token(b"ABCDEF");

        // Token in the Authorization header.
        let content = access(&service, &token).await;
        assert!(content.contains("hello"));

        // Token in the query string.
        let content = TestClient::get(format!("http://127.0.0.1:5801/hello?jwt_token={token}"))
            .send(&service)
            .await
            .take_string()
            .await
            .unwrap();
        assert!(content.contains("hello"));

        // Token in a cookie.
        let content = TestClient::get("http://127.0.0.1:5801/hello")
            .add_header("Cookie", format!("jwt_token={token}"), true)
            .send(&service)
            .await
            .take_string()
            .await
            .unwrap();
        assert!(content.contains("hello"));

        // No token at all: rejected with 401 because `force_passed` is off.
        let content = TestClient::get("http://127.0.0.1:5801/hello")
            .send(&service)
            .await
            .take_string()
            .await
            .unwrap();
        assert!(content.contains("Unauthorized"));

        // Token signed with the wrong secret: rejected with 403.
        let token = make_token(b"ABCDEFG");
        let content = access(&service, &token).await;
        assert!(content.contains("Forbidden"));
    }

    #[tokio::test]
    async fn test_jwt_auth_force_passed_records_state() {
        #[handler]
        async fn show_state(depot: &mut Depot) -> String {
            format!("{:?}", depot.jwt_auth_state())
        }

        let auth_handler: JwtAuth<JwtClaims, ConstDecoder> =
            JwtAuth::new(ConstDecoder::from_secret(b"ABCDEF")).force_passed(true);
        let router = Router::new()
            .hoop(auth_handler)
            .push(Router::with_path("state").get(show_state));
        let service = Service::new(router);

        async fn access(service: &Service, token: Option<&str>) -> String {
            let mut client = TestClient::get("http://127.0.0.1:5801/state");
            if let Some(token) = token {
                client = client.add_header("Authorization", format!("Bearer {token}"), true);
            }
            client.send(service).await.take_string().await.unwrap()
        }

        let good = make_token(b"ABCDEF");
        let bad = make_token(b"ABCDEFG");

        // With `force_passed`, every request reaches the handler and only the state differs.
        assert_eq!(access(&service, None).await, "Unauthorized");
        assert_eq!(access(&service, Some(good.as_str())).await, "Authorized");
        assert_eq!(access(&service, Some(bad.as_str())).await, "Forbidden");
    }
}