1use anyhow::Result;
2use async_trait::async_trait;
3use axum::{
4 extract::Request,
5 http::StatusCode,
6 response::{IntoResponse, Response},
7};
8use axum_help::filter::{drain_body, AsyncPredicate};
9use headers::{authorization::Basic, Authorization, HeaderMapExt};
10use parking_lot::Mutex;
11use std::{collections::VecDeque, fmt::Display, future::Future, pin::Pin, sync::Arc};
12
13use self::digest::unauthorized;
14
15#[async_trait]
16pub trait AuthCheckPredicate {
17 type CheckInfo: Clone + Send + Sync + 'static;
18
19 async fn check(
20 &self,
21 username: impl Into<String> + Send,
22 password: impl Into<String> + Send,
23 ) -> Result<Self::CheckInfo>;
24
25 fn username(&self) -> &str;
26 fn password(&self) -> &str;
27}
28
/// Request predicate enforcing HTTP Basic authentication.
///
/// Field `0` is the credential checker; field `1` is the response body sent with the
/// `401 Unauthorized` challenge when the request carries no usable Basic credentials.
///
/// Trait bounds live on the `impl` blocks that need them rather than on the struct itself, so
/// the type can be named and stored without repeating the bounds everywhere.
#[derive(Clone)]
pub struct AsyncBasicAuth<T>(T, String);
33
34impl<T> AsyncBasicAuth<T>
35where
36 T: AuthCheckPredicate + Clone + Send,
37{
38 pub fn new(p: T) -> Self {
39 Self(p, "Need basic authenticate".to_string())
40 }
41
42 pub fn err_msg(mut self, msg: impl Into<String>) -> Self {
43 self.1 = msg.into();
44 self
45 }
46}
47
48impl<T> AsyncPredicate<Request> for AsyncBasicAuth<T>
49where
50 T: AuthCheckPredicate + Clone + Send + Sync + 'static,
51{
52 type Request = Request;
53 type Response = Response;
54 type Future = Pin<Box<dyn Future<Output = Result<Self::Request, Self::Response>> + Send>>;
55
56 fn check(&mut self, mut request: Request) -> Self::Future {
57 let mut err = self.1.clone();
58 let auth = self.0.clone();
59 Box::pin(async move {
60 if let Some(authorization) = request.headers().typed_get::<Authorization<Basic>>() {
61 match auth
62 .check(authorization.username(), authorization.password())
63 .await
64 {
65 Err(e) => err = format!("check authorization error: {:?}", e),
66 Ok(ci) => {
67 request.extensions_mut().insert(ci);
68 return Ok(request);
69 }
70 }
71 }
72
73 drain_body(request).await;
74 Err((
75 StatusCode::UNAUTHORIZED,
76 [("WWW-Authenticate", "Basic"); 1],
77 err,
78 )
79 .into_response())
80 })
81 }
82}
83
84#[derive(Clone)]
85pub struct AsyncDigestAuth<T>
86where
87 T: AuthCheckPredicate + Clone + Send,
88{
89 inner: T,
90 err: String,
91 srv_name: String,
92 nonces: Arc<Mutex<VecDeque<(String, String)>>>,
93}
94
95impl<T> AsyncDigestAuth<T>
96where
97 T: AuthCheckPredicate + Clone + Send,
98{
99 pub fn new(p: T) -> Self {
100 Self {
101 inner: p,
102 srv_name: env!("CARGO_PKG_NAME").to_owned(),
103 err: "Need digest authenticate".to_string(),
104 nonces: Arc::new(Mutex::new(VecDeque::new())),
105 }
106 }
107
108 pub fn srv_name(mut self, name: impl Into<String>) -> Self {
109 self.srv_name = name.into();
110 self
111 }
112
113 pub fn err_msg(mut self, msg: impl Into<String>) -> Self {
114 self.err = msg.into();
115 self
116 }
117}
118
119impl<T> AsyncPredicate<Request> for AsyncDigestAuth<T>
120where
121 T: AuthCheckPredicate + Clone + Send + Sync + 'static,
122{
123 type Request = Request;
124 type Response = Response;
125 type Future = Pin<Box<dyn Future<Output = Result<Self::Request, Self::Response>> + Send>>;
126
127 fn check(&mut self, request: Request) -> Self::Future {
128 let err = self.err.clone();
129 let inner = self.inner.clone();
130 let srv_name = self.srv_name.clone();
131 let nonces = self.nonces.clone();
132 Box::pin(async move {
133 if let Some(auth_header) = request.headers().get("Authorization") {
134 let auth = digest::Authorization::from_header(
135 auth_header.to_str().map_err(bad_request)?,
136 )
137 .map_err(bad_request)?;
138
139 return auth.check(
140 inner.username(),
141 inner.password(),
142 nonces,
143 request,
144 srv_name,
145 );
146 }
147
148 drain_body(request).await;
149 Err(unauthorized(nonces, err, srv_name))
150 })
151 }
152}
153
154fn bad_request(e: impl Display) -> Response {
155 (
156 StatusCode::BAD_REQUEST,
157 format!("Bad request in header Authorization: {}", e),
158 )
159 .into_response()
160}
161
162mod digest {
163 use anyhow::{anyhow, bail, Result};
164 use axum::{
165 extract::Request,
166 http::StatusCode,
167 response::{IntoResponse, Response},
168 };
169 use parking_lot::Mutex;
170 use rand::{distributions::Alphanumeric, thread_rng, Rng};
171 use std::{collections::VecDeque, fmt::Debug, sync::Arc};
172
173 #[derive(Default, Debug)]
174 pub(super) struct Authorization {
175 pub(super) username: String,
176 pub(super) realm: String,
177 pub(super) nonce: String,
178 pub(super) uri: String,
179 pub(super) qop: String,
180 pub(super) nc: String,
181 pub(super) cnonce: String,
182 pub(super) response: String,
183 pub(super) opaque: String,
184 }
185
186 impl Authorization {
187 pub(super) fn check(
188 &self,
189 username: impl AsRef<str>,
190 password: impl AsRef<str>,
191 nonces: Arc<Mutex<VecDeque<(String, String)>>>,
192 request: Request,
193 srv_name: impl AsRef<str>,
194 ) -> Result<Request, Response> {
195 let mut found_nonce = false;
196 {
197 let mut nonce_list = nonces.lock();
198 let mut index = nonce_list.len().saturating_sub(1);
199
200 for (nonce, opaque) in nonce_list.iter().rev() {
201 if nonce == &self.nonce || opaque == &self.opaque {
202 found_nonce = true;
203 nonce_list.remove(index);
204 break;
205 }
206
207 index = index.saturating_sub(1);
208 }
209 }
210
211 if !found_nonce {
212 return Err(unauthorized(nonces, "invalid nonce or opaque", srv_name));
213 }
214
215 log::debug!("digest request: {:?}", request);
216 let ha1 = md5::compute(format!(
217 "{}:{}:{}",
218 username.as_ref(),
219 self.realm,
220 password.as_ref()
221 ));
222 let ha2 = md5::compute(format!("{}:{}", request.method(), self.uri));
223 let password = md5::compute(format!(
224 "{:x}:{}:{}:{}:{}:{:x}",
225 ha1, self.nonce, self.nc, self.cnonce, self.qop, ha2
226 ));
227
228 if format!("{:x}", password) != self.response {
229 return Err(unauthorized(
230 nonces,
231 "invalid username or password",
232 srv_name,
233 ));
234 }
235
236 Ok(request)
237 }
238
239 const DIGEST_MARK: &'static str = "Digest";
240 pub(super) fn from_header(auth: impl AsRef<str>) -> Result<Self> {
241 let auth = auth.as_ref();
242 let (mark, content) = auth.split_at(Self::DIGEST_MARK.len());
243 let content = content.trim();
244 if mark != Self::DIGEST_MARK {
245 bail!("only support digest authorization");
246 }
247
248 let mut result = Authorization::default();
249 for c in content.split(',') {
250 let c = c.trim();
251 let (k, v) = c
252 .split_once('=')
253 .ok_or_else(|| anyhow!("invalid part of authorization: {}", c))?;
254 let v = v.trim_matches('"');
255 match k {
256 "username" => result.username = v.to_string(),
257 "realm" => result.realm = v.to_string(),
258 "nonce" => result.nonce = v.to_string(),
259 "uri" => result.uri = v.to_string(),
260 "qop" => result.qop = v.to_string(),
261 "nc" => result.nc = v.to_string(),
262 "cnonce" => result.cnonce = v.to_string(),
263 "response" => result.response = v.to_string(),
264 "opaque" => result.opaque = v.to_string(),
265 _ => {
266 log::warn!("unknown authorization part: {}", c);
267 continue;
268 }
269 }
270 }
271
272 log::debug!("digest auth: {:?}", result);
273 Ok(result)
274 }
275 }
276
277 pub(super) fn unauthorized(
278 nonces: Arc<Mutex<VecDeque<(String, String)>>>,
279 msg: impl Into<String>,
280 srv_name: impl AsRef<str>,
281 ) -> Response {
282 let realm = format!("Login to {}", srv_name.as_ref());
283 let nonce = rand_string(32);
284 let opaque = rand_string(32);
285
286 let www_authenticate = format!(
287 r#"Digest realm="{}",qop="auth",nonce="{}",opaque="{}""#,
288 realm, nonce, opaque
289 );
290
291 {
292 let mut nonce_list = nonces.lock();
293 while nonce_list.len() >= 256 {
294 nonce_list.pop_front();
295 }
296
297 nonce_list.push_back((nonce, opaque));
298 }
299
300 (
301 StatusCode::UNAUTHORIZED,
302 [("WWW-Authenticate", www_authenticate); 1],
303 msg.into(),
304 )
305 .into_response()
306 }
307
308 fn rand_string(count: usize) -> String {
309 thread_rng()
310 .sample_iter(Alphanumeric)
311 .take(count)
312 .map(char::from)
313 .collect()
314 }
315}