1use self::digest::unauthorized;
2use anyhow::Result;
3use axum::{
4 extract::Request,
5 http::StatusCode,
6 response::{IntoResponse, Response},
7};
8use axum_help::filter::{drain_body, AsyncPredicate};
9use headers::{authorization::Basic, Authorization, HeaderMapExt};
10use parking_lot::Mutex;
11use std::{collections::VecDeque, fmt::Display, future::Future, pin::Pin, sync::Arc};
12
13pub trait AuthCheckPredicate {
14 type CheckInfo: Clone + Send + Sync + 'static;
15
16 fn check(
17 &self,
18 username: impl Into<String> + Send,
19 password: impl Into<String> + Send,
20 ) -> impl Future<Output = Result<Self::CheckInfo>> + Send;
21
22 fn username(&self) -> &str {
23 unimplemented!()
24 }
25
26 fn password(&self) -> &str {
27 unimplemented!()
28 }
29}
30
31impl<T> AuthCheckPredicate for Arc<T>
32where
33 T: AuthCheckPredicate,
34{
35 type CheckInfo = T::CheckInfo;
36
37 fn check(
38 &self,
39 username: impl Into<String> + Send,
40 password: impl Into<String> + Send,
41 ) -> impl Future<Output = Result<Self::CheckInfo>> + Send {
42 self.as_ref().check(username, password)
43 }
44}
45
46#[derive(Clone)]
47pub struct AsyncBasicAuth<T>(T, String)
48where
49 T: AuthCheckPredicate + Clone + Send;
50
51impl<T> AsyncBasicAuth<T>
52where
53 T: AuthCheckPredicate + Clone + Send,
54{
55 pub fn new(p: T) -> Self {
56 Self(p, "Need basic authenticate".to_string())
57 }
58
59 pub fn err_msg(mut self, msg: impl Into<String>) -> Self {
60 self.1 = msg.into();
61 self
62 }
63}
64
65impl<T> AsyncPredicate<Request> for AsyncBasicAuth<T>
66where
67 T: AuthCheckPredicate + Clone + Send + Sync + 'static,
68{
69 type Request = Request;
70 type Response = Response;
71 type Future = Pin<Box<dyn Future<Output = Result<Self::Request, Self::Response>> + Send>>;
72
73 fn check(&self, mut request: Request) -> Self::Future {
74 let mut err = self.1.clone();
75 let auth = self.0.clone();
76 Box::pin(async move {
77 if let Some(authorization) = request.headers().typed_get::<Authorization<Basic>>() {
78 match auth
79 .check(authorization.username(), authorization.password())
80 .await
81 {
82 Err(e) => err = format!("check authorization error: {:?}", e),
83 Ok(ci) => {
84 request.extensions_mut().insert(ci);
85 return Ok(request);
86 }
87 }
88 }
89
90 drain_body(request).await;
91 Err((
92 StatusCode::UNAUTHORIZED,
93 [("WWW-Authenticate", "Basic"); 1],
94 err,
95 )
96 .into_response())
97 })
98 }
99}
100
101#[derive(Clone)]
102pub struct AsyncDigestAuth<T>
103where
104 T: AuthCheckPredicate + Clone + Send,
105{
106 inner: T,
107 err: String,
108 srv_name: String,
109 nonces: Arc<Mutex<VecDeque<(String, String)>>>,
110}
111
112impl<T> AsyncDigestAuth<T>
113where
114 T: AuthCheckPredicate + Clone + Send,
115{
116 pub fn new(p: T) -> Self {
117 Self {
118 inner: p,
119 srv_name: env!("CARGO_PKG_NAME").to_owned(),
120 err: "Need digest authenticate".to_string(),
121 nonces: Arc::new(Mutex::new(VecDeque::new())),
122 }
123 }
124
125 pub fn srv_name(mut self, name: impl Into<String>) -> Self {
126 self.srv_name = name.into();
127 self
128 }
129
130 pub fn err_msg(mut self, msg: impl Into<String>) -> Self {
131 self.err = msg.into();
132 self
133 }
134}
135
136impl<T> AsyncPredicate<Request> for AsyncDigestAuth<T>
137where
138 T: AuthCheckPredicate + Clone + Send + Sync + 'static,
139{
140 type Request = Request;
141 type Response = Response;
142 type Future = Pin<Box<dyn Future<Output = Result<Self::Request, Self::Response>> + Send>>;
143
144 fn check(&self, request: Request) -> Self::Future {
145 let err = self.err.clone();
146 let inner = self.inner.clone();
147 let srv_name = self.srv_name.clone();
148 let nonces = self.nonces.clone();
149 Box::pin(async move {
150 if let Some(auth_header) = request.headers().get("Authorization") {
151 let auth =
152 digest::Authorization::from_header(auth_header.to_str().map_err(bad_request)?)
153 .map_err(bad_request)?;
154
155 return auth.check(
156 inner.username(),
157 inner.password(),
158 nonces,
159 request,
160 srv_name,
161 );
162 }
163
164 drain_body(request).await;
165 Err(unauthorized(nonces, err, srv_name))
166 })
167 }
168}
169
170fn bad_request(e: impl Display) -> Response {
171 (
172 StatusCode::BAD_REQUEST,
173 format!("Bad request in header Authorization: {}", e),
174 )
175 .into_response()
176}
177
178mod digest {
179 use anyhow::{anyhow, bail, Result};
180 use axum::{
181 extract::Request,
182 http::StatusCode,
183 response::{IntoResponse, Response},
184 };
185 use parking_lot::Mutex;
186 use rand::{distributions::Alphanumeric, thread_rng, Rng};
187 use std::{collections::VecDeque, fmt::Debug, sync::Arc};
188
189 #[derive(Default, Debug)]
190 pub(super) struct Authorization {
191 pub(super) username: String,
192 pub(super) realm: String,
193 pub(super) nonce: String,
194 pub(super) uri: String,
195 pub(super) qop: String,
196 pub(super) nc: String,
197 pub(super) cnonce: String,
198 pub(super) response: String,
199 pub(super) opaque: String,
200 }
201
202 impl Authorization {
203 pub(super) fn check(
204 &self,
205 username: impl AsRef<str>,
206 password: impl AsRef<str>,
207 nonces: Arc<Mutex<VecDeque<(String, String)>>>,
208 request: Request,
209 srv_name: impl AsRef<str>,
210 ) -> Result<Request, Response> {
211 let mut found_nonce = false;
212 {
213 let mut nonce_list = nonces.lock();
214 let mut index = nonce_list.len().saturating_sub(1);
215
216 for (nonce, opaque) in nonce_list.iter().rev() {
217 if nonce == &self.nonce || opaque == &self.opaque {
218 found_nonce = true;
219 nonce_list.remove(index);
220 break;
221 }
222
223 index = index.saturating_sub(1);
224 }
225 }
226
227 if !found_nonce {
228 return Err(unauthorized(nonces, "invalid nonce or opaque", srv_name));
229 }
230
231 log::debug!("digest request: {:?}", request);
232 let ha1 = md5::compute(format!(
233 "{}:{}:{}",
234 username.as_ref(),
235 self.realm,
236 password.as_ref()
237 ));
238 let ha2 = md5::compute(format!("{}:{}", request.method(), self.uri));
239 let password = md5::compute(format!(
240 "{:x}:{}:{}:{}:{}:{:x}",
241 ha1, self.nonce, self.nc, self.cnonce, self.qop, ha2
242 ));
243
244 if format!("{:x}", password) != self.response {
245 return Err(unauthorized(
246 nonces,
247 "invalid username or password",
248 srv_name,
249 ));
250 }
251
252 Ok(request)
253 }
254
255 const DIGEST_MARK: &'static str = "Digest";
256 pub(super) fn from_header(auth: impl AsRef<str>) -> Result<Self> {
257 let auth = auth.as_ref();
258 let (mark, content) = auth.split_at(Self::DIGEST_MARK.len());
259 let content = content.trim();
260 if mark != Self::DIGEST_MARK {
261 bail!("only support digest authorization");
262 }
263
264 let mut result = Authorization::default();
265 for c in content.split(',') {
266 let c = c.trim();
267 let (k, v) = c
268 .split_once('=')
269 .ok_or_else(|| anyhow!("invalid part of authorization: {}", c))?;
270 let v = v.trim_matches('"');
271 match k {
272 "username" => result.username = v.to_string(),
273 "realm" => result.realm = v.to_string(),
274 "nonce" => result.nonce = v.to_string(),
275 "uri" => result.uri = v.to_string(),
276 "qop" => result.qop = v.to_string(),
277 "nc" => result.nc = v.to_string(),
278 "cnonce" => result.cnonce = v.to_string(),
279 "response" => result.response = v.to_string(),
280 "opaque" => result.opaque = v.to_string(),
281 _ => {
282 log::warn!("unknown authorization part: {}", c);
283 continue;
284 }
285 }
286 }
287
288 log::debug!("digest auth: {:?}", result);
289 Ok(result)
290 }
291 }
292
293 pub(super) fn unauthorized(
294 nonces: Arc<Mutex<VecDeque<(String, String)>>>,
295 msg: impl Into<String>,
296 srv_name: impl AsRef<str>,
297 ) -> Response {
298 let realm = format!("Login to {}", srv_name.as_ref());
299 let nonce = rand_string(32);
300 let opaque = rand_string(32);
301
302 let www_authenticate = format!(
303 r#"Digest realm="{}",qop="auth",nonce="{}",opaque="{}""#,
304 realm, nonce, opaque
305 );
306
307 {
308 let mut nonce_list = nonces.lock();
309 while nonce_list.len() >= 256 {
310 nonce_list.pop_front();
311 }
312
313 nonce_list.push_back((nonce, opaque));
314 }
315
316 (
317 StatusCode::UNAUTHORIZED,
318 [("WWW-Authenticate", www_authenticate); 1],
319 msg.into(),
320 )
321 .into_response()
322 }
323
324 fn rand_string(count: usize) -> String {
325 thread_rng()
326 .sample_iter(Alphanumeric)
327 .take(count)
328 .map(char::from)
329 .collect()
330 }
331}