salvo_extra/
basic_auth.rs1use base64::engine::{general_purpose, Engine};
34use salvo_core::http::header::{HeaderName, AUTHORIZATION, PROXY_AUTHORIZATION};
35use salvo_core::http::{Request, Response, StatusCode};
36use salvo_core::{async_trait, Depot, Error, FlowCtrl, Handler};
37
/// Depot key under which [`BasicAuth`] stores the username of a request whose
/// credentials were accepted by the validator.
pub const USERNAME_KEY: &str = "::salvo::basic_auth::username";
40
41pub trait BasicAuthValidator: Send + Sync {
43 fn validate(&self, username: &str, password: &str, depot: &mut Depot) -> impl Future<Output = bool> + Send;
48}
49
/// Read access to the username that basic authentication recorded for the
/// current request.
pub trait BasicAuthDepotExt {
    /// Returns the stored username, or `None` when none is available.
    fn basic_auth_username(&self) -> Option<&str>;
}
55
56impl BasicAuthDepotExt for Depot {
57 fn basic_auth_username(&self) -> Option<&str> {
58 self.get::<String>(USERNAME_KEY).map(|v|&**v).ok()
59 }
60}
61
62pub struct BasicAuth<V: BasicAuthValidator> {
64 realm: String,
65 header_names: Vec<HeaderName>,
66 validator: V,
67}
68
69impl<V> BasicAuth<V>
70where
71 V: BasicAuthValidator,
72{
73 #[inline]
75 pub fn new(validator: V) -> Self {
76 BasicAuth {
77 realm: "realm".to_owned(),
78 header_names: vec![AUTHORIZATION, PROXY_AUTHORIZATION],
79 validator,
80 }
81 }
82
83 #[doc(hidden)]
84 #[inline]
85 pub fn set_header_names(mut self, header_names: impl Into<Vec<HeaderName>>) -> Self {
86 self.header_names = header_names.into();
87 self
88 }
89 #[doc(hidden)]
90 #[inline]
91 pub fn header_names(&self) -> &Vec<HeaderName> {
92 &self.header_names
93 }
94
95 #[doc(hidden)]
96 #[inline]
97 pub fn header_names_mut(&mut self) -> &mut Vec<HeaderName> {
98 &mut self.header_names
99 }
100
101 #[doc(hidden)]
102 #[inline]
103 pub fn ask_credentials(&self, res: &mut Response) {
104 ask_credentials(res, &self.realm)
105 }
106
107 #[doc(hidden)]
108 #[inline]
109 pub fn parse_credentials(&self, req: &Request) -> Result<(String, String), Error> {
110 parse_credentials(req, &self.header_names)
111 }
112}
113
114#[doc(hidden)]
115#[inline]
116pub fn ask_credentials(res: &mut Response, realm: impl AsRef<str>) {
117 res.headers_mut().insert(
118 "WWW-Authenticate",
119 format!("Basic realm={:?}", realm.as_ref())
120 .parse()
121 .expect("parse WWW-Authenticate failed"),
122 );
123 res.status_code(StatusCode::UNAUTHORIZED);
124}
125
126#[doc(hidden)]
127pub fn parse_credentials(req: &Request, header_names: &[HeaderName]) -> Result<(String, String), Error> {
128 let mut authorization = "";
129 for header_name in header_names {
130 if let Some(header_value) = req.headers().get(header_name) {
131 authorization = header_value.to_str().unwrap_or_default();
132 if !authorization.is_empty() {
133 break;
134 }
135 }
136 }
137
138 if authorization.starts_with("Basic") {
139 if let Some((_, auth)) = authorization.split_once(' ') {
140 let auth = general_purpose::STANDARD.decode(auth).map_err(Error::other)?;
141 let auth = auth.iter().map(|&c| c as char).collect::<String>();
142 if let Some((username, password)) = auth.split_once(':') {
143 return Ok((username.to_owned(), password.to_owned()));
144 } else {
145 return Err(Error::other("`authorization` has bad format"));
146 }
147 }
148 }
149 Err(Error::other("parse http header failed"))
150}
151
152#[async_trait]
153impl<V> Handler for BasicAuth<V>
154where
155 V: BasicAuthValidator + 'static,
156{
157 async fn handle(&self, req: &mut Request, depot: &mut Depot, res: &mut Response, ctrl: &mut FlowCtrl) {
158 if let Ok((username, password)) = self.parse_credentials(req) {
159 if self.validator.validate(&username, &password, depot).await {
160 depot.insert(USERNAME_KEY, username);
161 ctrl.call_next(req, depot, res).await;
162 return;
163 }
164 }
165 self.ask_credentials(res);
166 ctrl.skip_rest();
167 }
168}
169
#[cfg(test)]
mod tests {
    use salvo_core::prelude::*;
    use salvo_core::test::{ResponseExt, TestClient};

    use super::*;

    #[handler]
    async fn hello() -> &'static str {
        "Hello"
    }

    /// Accepts only the `root` / `pwd` credential pair.
    struct Validator;

    impl BasicAuthValidator for Validator {
        async fn validate(&self, username: &str, password: &str, _depot: &mut Depot) -> bool {
            matches!((username, password), ("root", "pwd"))
        }
    }

    /// Sends a GET authenticated as `root` with `password` and returns the body text.
    async fn body_for_password(service: &Service, password: &str) -> String {
        TestClient::get("http://127.0.0.1:5800/")
            .basic_auth("root", Some(password))
            .send(service)
            .await
            .take_string()
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn test_basic_auth() {
        let service = Service::new(Router::with_hoop(BasicAuth::new(Validator)).goal(hello));

        let accepted = body_for_password(&service, "pwd").await;
        assert!(accepted.contains("Hello"));

        let rejected = body_for_password(&service, "pwd2").await;
        assert!(rejected.contains("Unauthorized"));
    }
}