// salvo_extra/basic_auth.rs

1//! Middleware for HTTP Basic Authentication.
2//!
3//! This middleware implements the standard HTTP Basic Authentication scheme as described in RFC 7617.
//! It extracts credentials from the `Authorization` header (or, when that is absent or empty, the
//! `Proxy-Authorization` header) and validates them against your custom validator.
5//!
6//! # Example
7//!
8//! ```no_run
9//! use salvo_core::prelude::*;
10//! use salvo_extra::basic_auth::{BasicAuth, BasicAuthValidator};
11//!
12//! struct Validator;
13//! impl BasicAuthValidator for Validator {
14//!     async fn validate(&self, username: &str, password: &str, _depot: &mut Depot) -> bool {
15//!         username == "root" && password == "pwd"
16//!     }
17//! }
18//! 
19//! #[handler]
20//! async fn hello() -> &'static str {
21//!     "Hello"
22//! }
23//!
24//! #[tokio::main]
25//! async fn main() {
26//!     let auth_handler = BasicAuth::new(Validator);
27//!     let router = Router::with_hoop(auth_handler).goal(hello);
28//!
29//!     let acceptor = TcpListener::new("0.0.0.0:5800").bind().await;
30//!     Server::new(acceptor).serve(router).await;
31//! }
32//! ```
use base64::engine::{general_purpose, Engine};
use salvo_core::http::header::{HeaderName, HeaderValue, AUTHORIZATION, PROXY_AUTHORIZATION};
use salvo_core::http::{Request, Response, StatusCode};
use salvo_core::{async_trait, Depot, Error, FlowCtrl, Handler};
37
/// Depot key under which [`BasicAuth`] stores the authenticated username.
///
/// Prefer [`BasicAuthDepotExt::basic_auth_username`] over reading this key directly.
pub const USERNAME_KEY: &str = "::salvo::basic_auth::username";
40
/// Validator for Basic Authentication credentials.
///
/// Implemented by users of [`BasicAuth`] to decide whether the username/password pair
/// decoded from the request is accepted.
pub trait BasicAuthValidator: Send + Sync {
    /// Validates whether the provided username and password are correct.
    ///
    /// Implement this method to check credentials against your authentication system.
    /// `depot` is the depot of the request being authenticated, so the validator can read
    /// or record per-request data.
    /// Return `true` if authentication succeeds, `false` otherwise; on `false` the
    /// middleware answers with `401 Unauthorized` and skips the remaining handlers.
    fn validate(&self, username: &str, password: &str, depot: &mut Depot) -> impl Future<Output = bool> + Send;
}
49
/// Extension trait for retrieving the authenticated username from a [`Depot`].
pub trait BasicAuthDepotExt {
    /// Returns the username that passed Basic Authentication for the current request,
    /// or `None` if no username has been recorded in the depot.
    fn basic_auth_username(&self) -> Option<&str>;
}
55
56impl BasicAuthDepotExt for Depot {
57    fn basic_auth_username(&self) -> Option<&str> {
58        self.get::<String>(USERNAME_KEY).map(|v|&**v).ok()
59    }
60}
61
62/// BasicAuth
63pub struct BasicAuth<V: BasicAuthValidator> {
64    realm: String,
65    header_names: Vec<HeaderName>,
66    validator: V,
67}
68
69impl<V> BasicAuth<V>
70where
71    V: BasicAuthValidator,
72{
73    /// Create new `BasicAuthValidator`.
74    #[inline]
75    pub fn new(validator: V) -> Self {
76        BasicAuth {
77            realm: "realm".to_owned(),
78            header_names: vec![AUTHORIZATION, PROXY_AUTHORIZATION],
79            validator,
80        }
81    }
82
83    #[doc(hidden)]
84    #[inline]
85    pub fn set_header_names(mut self, header_names: impl Into<Vec<HeaderName>>) -> Self {
86        self.header_names = header_names.into();
87        self
88    }
89    #[doc(hidden)]
90    #[inline]
91    pub fn header_names(&self) -> &Vec<HeaderName> {
92        &self.header_names
93    }
94
95    #[doc(hidden)]
96    #[inline]
97    pub fn header_names_mut(&mut self) -> &mut Vec<HeaderName> {
98        &mut self.header_names
99    }
100
101    #[doc(hidden)]
102    #[inline]
103    pub fn ask_credentials(&self, res: &mut Response) {
104        ask_credentials(res, &self.realm)
105    }
106
107    #[doc(hidden)]
108    #[inline]
109    pub fn parse_credentials(&self, req: &Request) -> Result<(String, String), Error> {
110        parse_credentials(req, &self.header_names)
111    }
112}
113
114#[doc(hidden)]
115#[inline]
116pub fn ask_credentials(res: &mut Response, realm: impl AsRef<str>) {
117    res.headers_mut().insert(
118        "WWW-Authenticate",
119        format!("Basic realm={:?}", realm.as_ref())
120            .parse()
121            .expect("parse WWW-Authenticate failed"),
122    );
123    res.status_code(StatusCode::UNAUTHORIZED);
124}
125
126#[doc(hidden)]
127pub fn parse_credentials(req: &Request, header_names: &[HeaderName]) -> Result<(String, String), Error> {
128    let mut authorization = "";
129    for header_name in header_names {
130        if let Some(header_value) = req.headers().get(header_name) {
131            authorization = header_value.to_str().unwrap_or_default();
132            if !authorization.is_empty() {
133                break;
134            }
135        }
136    }
137
138    if authorization.starts_with("Basic") {
139        if let Some((_, auth)) = authorization.split_once(' ') {
140            let auth = general_purpose::STANDARD.decode(auth).map_err(Error::other)?;
141            let auth = auth.iter().map(|&c| c as char).collect::<String>();
142            if let Some((username, password)) = auth.split_once(':') {
143                return Ok((username.to_owned(), password.to_owned()));
144            } else {
145                return Err(Error::other("`authorization` has bad format"));
146            }
147        }
148    }
149    Err(Error::other("parse http header failed"))
150}
151
152#[async_trait]
153impl<V> Handler for BasicAuth<V>
154where
155    V: BasicAuthValidator + 'static,
156{
157    async fn handle(&self, req: &mut Request, depot: &mut Depot, res: &mut Response, ctrl: &mut FlowCtrl) {
158        if let Ok((username, password)) = self.parse_credentials(req) {
159            if self.validator.validate(&username, &password, depot).await {
160                depot.insert(USERNAME_KEY, username);
161                ctrl.call_next(req, depot, res).await;
162                return;
163            }
164        }
165        self.ask_credentials(res);
166        ctrl.skip_rest();
167    }
168}
169
#[cfg(test)]
mod tests {
    use salvo_core::prelude::*;
    use salvo_core::test::{ResponseExt, TestClient};

    use super::*;

    #[handler]
    async fn hello() -> &'static str {
        "Hello"
    }

    /// Echoes back the username `BasicAuth` recorded in the depot.
    #[handler]
    async fn whoami(depot: &mut Depot) -> String {
        depot.basic_auth_username().unwrap_or_default().to_owned()
    }

    struct Validator;
    impl BasicAuthValidator for Validator {
        async fn validate(&self, username: &str, password: &str, _depot: &mut Depot) -> bool {
            username == "root" && password == "pwd"
        }
    }

    #[tokio::test]
    async fn test_basic_auth() {
        let auth_handler = BasicAuth::new(Validator);
        let router = Router::with_hoop(auth_handler).goal(hello);
        let service = Service::new(router);

        let content = TestClient::get("http://127.0.0.1:5800/")
            .basic_auth("root", Some("pwd"))
            .send(&service)
            .await
            .take_string()
            .await
            .unwrap();
        assert!(content.contains("Hello"));

        let content = TestClient::get("http://127.0.0.1:5800/")
            .basic_auth("root", Some("pwd2"))
            .send(&service)
            .await
            .take_string()
            .await
            .unwrap();
        assert!(content.contains("Unauthorized"));
    }

    /// Requests without credentials must be challenged with the default realm.
    #[tokio::test]
    async fn test_basic_auth_missing_credentials() {
        let router = Router::with_hoop(BasicAuth::new(Validator)).goal(hello);
        let service = Service::new(router);

        let res = TestClient::get("http://127.0.0.1:5800/").send(&service).await;
        assert_eq!(res.status_code, Some(StatusCode::UNAUTHORIZED));
        assert_eq!(
            res.headers()
                .get("WWW-Authenticate")
                .and_then(|value| value.to_str().ok()),
            Some("Basic realm=\"realm\"")
        );
    }

    /// A successfully authenticated username must be readable through `BasicAuthDepotExt`.
    #[tokio::test]
    async fn test_basic_auth_username_in_depot() {
        let router = Router::with_hoop(BasicAuth::new(Validator)).goal(whoami);
        let service = Service::new(router);

        let content = TestClient::get("http://127.0.0.1:5800/")
            .basic_auth("root", Some("pwd"))
            .send(&service)
            .await
            .take_string()
            .await
            .unwrap();
        assert_eq!(content, "root");
    }
}