spa_rs/
auth.rs

1use self::digest::unauthorized;
2use anyhow::Result;
3use axum::{
4    extract::Request,
5    http::StatusCode,
6    response::{IntoResponse, Response},
7};
8use axum_help::filter::{drain_body, AsyncPredicate};
9use headers::{authorization::Basic, Authorization, HeaderMapExt};
10use parking_lot::Mutex;
11use std::{collections::VecDeque, fmt::Display, future::Future, pin::Pin, sync::Arc};
12
13pub trait AuthCheckPredicate {
14    type CheckInfo: Clone + Send + Sync + 'static;
15
16    fn check(
17        &self,
18        username: impl Into<String> + Send,
19        password: impl Into<String> + Send,
20    ) -> impl Future<Output = Result<Self::CheckInfo>> + Send;
21
22    fn username(&self) -> &str {
23        unimplemented!()
24    }
25
26    fn password(&self) -> &str {
27        unimplemented!()
28    }
29}
30
31impl<T> AuthCheckPredicate for Arc<T>
32where
33    T: AuthCheckPredicate,
34{
35    type CheckInfo = T::CheckInfo;
36
37    fn check(
38        &self,
39        username: impl Into<String> + Send,
40        password: impl Into<String> + Send,
41    ) -> impl Future<Output = Result<Self::CheckInfo>> + Send {
42        self.as_ref().check(username, password)
43    }
44}
45
46#[derive(Clone)]
47pub struct AsyncBasicAuth<T>(T, String)
48where
49    T: AuthCheckPredicate + Clone + Send;
50
51impl<T> AsyncBasicAuth<T>
52where
53    T: AuthCheckPredicate + Clone + Send,
54{
55    pub fn new(p: T) -> Self {
56        Self(p, "Need basic authenticate".to_string())
57    }
58
59    pub fn err_msg(mut self, msg: impl Into<String>) -> Self {
60        self.1 = msg.into();
61        self
62    }
63}
64
65impl<T> AsyncPredicate<Request> for AsyncBasicAuth<T>
66where
67    T: AuthCheckPredicate + Clone + Send + Sync + 'static,
68{
69    type Request = Request;
70    type Response = Response;
71    type Future = Pin<Box<dyn Future<Output = Result<Self::Request, Self::Response>> + Send>>;
72
73    fn check(&self, mut request: Request) -> Self::Future {
74        let mut err = self.1.clone();
75        let auth = self.0.clone();
76        Box::pin(async move {
77            if let Some(authorization) = request.headers().typed_get::<Authorization<Basic>>() {
78                match auth
79                    .check(authorization.username(), authorization.password())
80                    .await
81                {
82                    Err(e) => err = format!("check authorization error: {:?}", e),
83                    Ok(ci) => {
84                        request.extensions_mut().insert(ci);
85                        return Ok(request);
86                    }
87                }
88            }
89
90            drain_body(request).await;
91            Err((
92                StatusCode::UNAUTHORIZED,
93                [("WWW-Authenticate", "Basic"); 1],
94                err,
95            )
96                .into_response())
97        })
98    }
99}
100
101#[derive(Clone)]
102pub struct AsyncDigestAuth<T>
103where
104    T: AuthCheckPredicate + Clone + Send,
105{
106    inner: T,
107    err: String,
108    srv_name: String,
109    nonces: Arc<Mutex<VecDeque<(String, String)>>>,
110}
111
112impl<T> AsyncDigestAuth<T>
113where
114    T: AuthCheckPredicate + Clone + Send,
115{
116    pub fn new(p: T) -> Self {
117        Self {
118            inner: p,
119            srv_name: env!("CARGO_PKG_NAME").to_owned(),
120            err: "Need digest authenticate".to_string(),
121            nonces: Arc::new(Mutex::new(VecDeque::new())),
122        }
123    }
124
125    pub fn srv_name(mut self, name: impl Into<String>) -> Self {
126        self.srv_name = name.into();
127        self
128    }
129
130    pub fn err_msg(mut self, msg: impl Into<String>) -> Self {
131        self.err = msg.into();
132        self
133    }
134}
135
136impl<T> AsyncPredicate<Request> for AsyncDigestAuth<T>
137where
138    T: AuthCheckPredicate + Clone + Send + Sync + 'static,
139{
140    type Request = Request;
141    type Response = Response;
142    type Future = Pin<Box<dyn Future<Output = Result<Self::Request, Self::Response>> + Send>>;
143
144    fn check(&self, request: Request) -> Self::Future {
145        let err = self.err.clone();
146        let inner = self.inner.clone();
147        let srv_name = self.srv_name.clone();
148        let nonces = self.nonces.clone();
149        Box::pin(async move {
150            if let Some(auth_header) = request.headers().get("Authorization") {
151                let auth =
152                    digest::Authorization::from_header(auth_header.to_str().map_err(bad_request)?)
153                        .map_err(bad_request)?;
154
155                return auth.check(
156                    inner.username(),
157                    inner.password(),
158                    nonces,
159                    request,
160                    srv_name,
161                );
162            }
163
164            drain_body(request).await;
165            Err(unauthorized(nonces, err, srv_name))
166        })
167    }
168}
169
170fn bad_request(e: impl Display) -> Response {
171    (
172        StatusCode::BAD_REQUEST,
173        format!("Bad request in header Authorization: {}", e),
174    )
175        .into_response()
176}
177
178mod digest {
179    use anyhow::{anyhow, bail, Result};
180    use axum::{
181        extract::Request,
182        http::StatusCode,
183        response::{IntoResponse, Response},
184    };
185    use parking_lot::Mutex;
186    use rand::{distributions::Alphanumeric, thread_rng, Rng};
187    use std::{collections::VecDeque, fmt::Debug, sync::Arc};
188
189    #[derive(Default, Debug)]
190    pub(super) struct Authorization {
191        pub(super) username: String,
192        pub(super) realm: String,
193        pub(super) nonce: String,
194        pub(super) uri: String,
195        pub(super) qop: String,
196        pub(super) nc: String,
197        pub(super) cnonce: String,
198        pub(super) response: String,
199        pub(super) opaque: String,
200    }
201
202    impl Authorization {
203        pub(super) fn check(
204            &self,
205            username: impl AsRef<str>,
206            password: impl AsRef<str>,
207            nonces: Arc<Mutex<VecDeque<(String, String)>>>,
208            request: Request,
209            srv_name: impl AsRef<str>,
210        ) -> Result<Request, Response> {
211            let mut found_nonce = false;
212            {
213                let mut nonce_list = nonces.lock();
214                let mut index = nonce_list.len().saturating_sub(1);
215
216                for (nonce, opaque) in nonce_list.iter().rev() {
217                    if nonce == &self.nonce || opaque == &self.opaque {
218                        found_nonce = true;
219                        nonce_list.remove(index);
220                        break;
221                    }
222
223                    index = index.saturating_sub(1);
224                }
225            }
226
227            if !found_nonce {
228                return Err(unauthorized(nonces, "invalid nonce or opaque", srv_name));
229            }
230
231            log::debug!("digest request: {:?}", request);
232            let ha1 = md5::compute(format!(
233                "{}:{}:{}",
234                username.as_ref(),
235                self.realm,
236                password.as_ref()
237            ));
238            let ha2 = md5::compute(format!("{}:{}", request.method(), self.uri));
239            let password = md5::compute(format!(
240                "{:x}:{}:{}:{}:{}:{:x}",
241                ha1, self.nonce, self.nc, self.cnonce, self.qop, ha2
242            ));
243
244            if format!("{:x}", password) != self.response {
245                return Err(unauthorized(
246                    nonces,
247                    "invalid username or password",
248                    srv_name,
249                ));
250            }
251
252            Ok(request)
253        }
254
255        const DIGEST_MARK: &'static str = "Digest";
256        pub(super) fn from_header(auth: impl AsRef<str>) -> Result<Self> {
257            let auth = auth.as_ref();
258            let (mark, content) = auth.split_at(Self::DIGEST_MARK.len());
259            let content = content.trim();
260            if mark != Self::DIGEST_MARK {
261                bail!("only support digest authorization");
262            }
263
264            let mut result = Authorization::default();
265            for c in content.split(',') {
266                let c = c.trim();
267                let (k, v) = c
268                    .split_once('=')
269                    .ok_or_else(|| anyhow!("invalid part of authorization: {}", c))?;
270                let v = v.trim_matches('"');
271                match k {
272                    "username" => result.username = v.to_string(),
273                    "realm" => result.realm = v.to_string(),
274                    "nonce" => result.nonce = v.to_string(),
275                    "uri" => result.uri = v.to_string(),
276                    "qop" => result.qop = v.to_string(),
277                    "nc" => result.nc = v.to_string(),
278                    "cnonce" => result.cnonce = v.to_string(),
279                    "response" => result.response = v.to_string(),
280                    "opaque" => result.opaque = v.to_string(),
281                    _ => {
282                        log::warn!("unknown authorization part: {}", c);
283                        continue;
284                    }
285                }
286            }
287
288            log::debug!("digest auth: {:?}", result);
289            Ok(result)
290        }
291    }
292
293    pub(super) fn unauthorized(
294        nonces: Arc<Mutex<VecDeque<(String, String)>>>,
295        msg: impl Into<String>,
296        srv_name: impl AsRef<str>,
297    ) -> Response {
298        let realm = format!("Login to {}", srv_name.as_ref());
299        let nonce = rand_string(32);
300        let opaque = rand_string(32);
301
302        let www_authenticate = format!(
303            r#"Digest realm="{}",qop="auth",nonce="{}",opaque="{}""#,
304            realm, nonce, opaque
305        );
306
307        {
308            let mut nonce_list = nonces.lock();
309            while nonce_list.len() >= 256 {
310                nonce_list.pop_front();
311            }
312
313            nonce_list.push_back((nonce, opaque));
314        }
315
316        (
317            StatusCode::UNAUTHORIZED,
318            [("WWW-Authenticate", www_authenticate); 1],
319            msg.into(),
320        )
321            .into_response()
322    }
323
324    fn rand_string(count: usize) -> String {
325        thread_rng()
326            .sample_iter(Alphanumeric)
327            .take(count)
328            .map(char::from)
329            .collect()
330    }
331}