rustic_server/
auth.rs

1use std::{borrow::Borrow, path::PathBuf};
2
3use abscissa_core::SecretString;
4use axum::{extract::FromRequestParts, http::request::Parts};
5use axum_auth::AuthBasic;
6use serde_derive::Deserialize;
7use std::sync::OnceLock;
8
9use crate::{
10    config::HtpasswdSettings,
11    error::{ApiErrorKind, ApiResult, AppResult},
12    htpasswd::{CredentialMap, Htpasswd},
13};
14
15// Static storage of our credentials
16pub static AUTH: OnceLock<Auth> = OnceLock::new();
17
18pub(crate) fn init_auth(auth: Auth) -> AppResult<()> {
19    let _ = AUTH.get_or_init(|| auth);
20    Ok(())
21}
22
23#[derive(Debug, Clone, Default)]
24pub struct Auth {
25    users: Option<CredentialMap>,
26}
27
28impl From<CredentialMap> for Auth {
29    fn from(users: CredentialMap) -> Self {
30        Self { users: Some(users) }
31    }
32}
33
34impl From<Htpasswd> for Auth {
35    fn from(htpasswd: Htpasswd) -> Self {
36        Self {
37            users: Some(htpasswd.credentials),
38        }
39    }
40}
41
42impl Auth {
43    pub fn from_file(disable_auth: bool, path: &PathBuf) -> AppResult<Self> {
44        Ok(if disable_auth {
45            Self::default()
46        } else {
47            Htpasswd::from_file(path)?.into()
48        })
49    }
50
51    pub fn from_config(settings: &HtpasswdSettings, path: PathBuf) -> AppResult<Self> {
52        Self::from_file(settings.is_disabled(), &path)
53    }
54
55    // verify verifies user/passwd against the credentials saved in users.
56    // returns true if Auth::users is None.
57    pub fn verify(&self, user: impl Into<String>, passwd: impl Into<String>) -> bool {
58        let user = user.into();
59        let passwd = passwd.into();
60
61        self.users.as_ref().map_or(true, |users| matches!(users.get(&user), Some(passwd_data) if htpasswd_verify::Htpasswd::from(passwd_data.to_string().borrow()).check(user, passwd)))
62    }
63
64    pub const fn is_disabled(&self) -> bool {
65        self.users.is_none()
66    }
67}
68
69#[derive(Deserialize, Debug)]
70pub struct BasicAuthFromRequest {
71    pub(crate) user: String,
72    pub(crate) _password: SecretString,
73}
74
75#[async_trait::async_trait]
76impl<S: Send + Sync> FromRequestParts<S> for BasicAuthFromRequest {
77    type Rejection = ApiErrorKind;
78
79    // FIXME: We also have a configuration flag do run without authentication
80    // This must be handled here too ... otherwise we get an Auth header missing error.
81    async fn from_request_parts(parts: &mut Parts, state: &S) -> ApiResult<Self> {
82        let checker = AUTH.get().unwrap();
83
84        let auth_result = AuthBasic::from_request_parts(parts, state).await;
85
86        tracing::debug!(?auth_result, "[AUTH]");
87
88        return match auth_result {
89            Ok(auth) => {
90                let AuthBasic((user, passw)) = auth;
91                let password = passw.unwrap_or_else(String::new);
92                if checker.verify(user.as_str(), password.as_str()) {
93                    Ok(Self {
94                        user,
95                        _password: password.into(),
96                    })
97                } else {
98                    Err(ApiErrorKind::UserAuthenticationError(user))
99                }
100            }
101            Err(_) => {
102                let user = String::new();
103                if checker.verify("", "") {
104                    return Ok(Self {
105                        user,
106                        _password: String::new().into(),
107                    });
108                }
109                Err(ApiErrorKind::AuthenticationHeaderError)
110            }
111        };
112    }
113}
114
#[cfg(test)]
mod test {
    use super::*;

    use crate::testing::{basic_auth_header_value, init_test_environment, server_config};

    use anyhow::Result;
    use axum::{
        body::Body,
        http::{Method, Request, StatusCode},
        routing::get,
        Router,
    };
    use http_body_util::BodyExt;
    use rstest::{fixture, rstest};
    use tower::ServiceExt;

    /// Loads the credentials from the htpasswd test fixture with
    /// authentication enabled.
    #[fixture]
    fn auth() -> Auth {
        let fixture_path = PathBuf::from("tests/fixtures/test_data/.htpasswd");
        Auth::from_file(false, &fixture_path).unwrap()
    }

    /// Handler echoing what the plain `AuthBasic` extractor produced.
    async fn format_auth_basic(AuthBasic((id, password)): AuthBasic) -> String {
        format!("Got {id} and {password:?}")
    }

    /// Handler echoing the user accepted by `BasicAuthFromRequest`.
    async fn format_handler_from_auth_request(auth: BasicAuthFromRequest) -> String {
        let BasicAuthFromRequest { user, .. } = auth;
        format!("User = {user}")
    }

    /// Collects a response body and decodes it as UTF-8 text.
    async fn body_text(body: Body) -> String {
        let bytes = body.collect().await.unwrap().to_bytes();
        String::from_utf8(bytes.to_vec()).unwrap()
    }

    #[rstest]
    fn test_auth_passes(auth: Auth) -> Result<()> {
        let accepted = auth.verify("rustic", "rustic");
        let rejected = auth.verify("rustic", "_rustic");

        assert!(accepted);
        assert!(!rejected);

        Ok(())
    }

    #[rstest]
    fn test_auth_from_file_passes(auth: Auth) {
        init_auth(auth).unwrap();

        // Read the value back from the global cell rather than the fixture.
        let stored = AUTH.get().unwrap();
        assert!(stored.verify("rustic", "rustic"));
        assert!(!stored.verify("rustic", "_rustic"));
    }

    /// The requests which should be answered with OK.
    #[tokio::test]
    async fn test_authentication_passes() {
        init_test_environment(server_config());

        // -----------------------------------------
        // Good credentials through the plain `AuthBasic` extractor
        // -----------------------------------------
        let header = basic_auth_header_value("My Username", Some("My Password"));
        let request = Request::builder()
            .method(Method::GET)
            .uri("/basic")
            .header("Authorization", header)
            .body(Body::empty())
            .unwrap();

        let app = Router::new().route("/basic", get(format_auth_basic));
        let resp = app.oneshot(request).await.unwrap();

        assert_eq!(resp.status(), StatusCode::OK);
        let text = body_text(resp.into_body()).await;
        assert_eq!(text, "Got My Username and Some(\"My Password\")");

        // -----------------------------------------
        // Good credentials through the `BasicAuthFromRequest` extractor
        // -----------------------------------------
        let header = basic_auth_header_value("rustic", Some("rustic"));
        let request = Request::builder()
            .method(Method::GET)
            .uri("/rustic_server")
            .header("Authorization", header)
            .body(Body::empty())
            .unwrap();

        let app = Router::new().route("/rustic_server", get(format_handler_from_auth_request));
        let resp = app.oneshot(request).await.unwrap();

        assert_eq!(resp.status(), StatusCode::OK);
        let text = body_text(resp.into_body()).await;
        assert_eq!(text, "User = rustic");
    }

    /// The requests which should be rejected with FORBIDDEN.
    #[tokio::test]
    async fn test_fail_authentication_passes() {
        init_test_environment(server_config());

        // -----------------------------------------
        // Wrong password for a known user
        // -----------------------------------------
        let header = basic_auth_header_value("rustic", Some("_rustic"));
        let request = Request::builder()
            .method(Method::GET)
            .uri("/rustic_server")
            .header("Authorization", header)
            .body(Body::empty())
            .unwrap();

        let app = Router::new().route("/rustic_server", get(format_handler_from_auth_request));
        let resp = app.oneshot(request).await.unwrap();

        assert_eq!(resp.status(), StatusCode::FORBIDDEN);

        // -----------------------------------------
        // No authentication header at all
        // -----------------------------------------
        let request = Request::builder()
            .method(Method::GET)
            .uri("/rustic_server")
            .body(Body::empty())
            .unwrap();

        let app = Router::new().route("/rustic_server", get(format_handler_from_auth_request));
        let resp = app.oneshot(request).await.unwrap();

        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }
}
256}