1use std::{borrow::Borrow, path::PathBuf};
2
3use abscissa_core::SecretString;
4use axum::{extract::FromRequestParts, http::request::Parts};
5use axum_auth::AuthBasic;
6use serde_derive::Deserialize;
7use std::sync::OnceLock;
8
9use crate::{
10 config::HtpasswdSettings,
11 error::{ApiErrorKind, ApiResult, AppResult},
12 htpasswd::{CredentialMap, Htpasswd},
13};
14
15pub static AUTH: OnceLock<Auth> = OnceLock::new();
17
18pub(crate) fn init_auth(auth: Auth) -> AppResult<()> {
19 let _ = AUTH.get_or_init(|| auth);
20 Ok(())
21}
22
23#[derive(Debug, Clone, Default)]
24pub struct Auth {
25 users: Option<CredentialMap>,
26}
27
28impl From<CredentialMap> for Auth {
29 fn from(users: CredentialMap) -> Self {
30 Self { users: Some(users) }
31 }
32}
33
34impl From<Htpasswd> for Auth {
35 fn from(htpasswd: Htpasswd) -> Self {
36 Self {
37 users: Some(htpasswd.credentials),
38 }
39 }
40}
41
42impl Auth {
43 pub fn from_file(disable_auth: bool, path: &PathBuf) -> AppResult<Self> {
44 Ok(if disable_auth {
45 Self::default()
46 } else {
47 Htpasswd::from_file(path)?.into()
48 })
49 }
50
51 pub fn from_config(settings: &HtpasswdSettings, path: PathBuf) -> AppResult<Self> {
52 Self::from_file(settings.is_disabled(), &path)
53 }
54
55 pub fn verify(&self, user: impl Into<String>, passwd: impl Into<String>) -> bool {
58 let user = user.into();
59 let passwd = passwd.into();
60
61 self.users.as_ref().map_or(true, |users| matches!(users.get(&user), Some(passwd_data) if htpasswd_verify::Htpasswd::from(passwd_data.to_string().borrow()).check(user, passwd)))
62 }
63
64 pub const fn is_disabled(&self) -> bool {
65 self.users.is_none()
66 }
67}
68
69#[derive(Deserialize, Debug)]
70pub struct BasicAuthFromRequest {
71 pub(crate) user: String,
72 pub(crate) _password: SecretString,
73}
74
75#[async_trait::async_trait]
76impl<S: Send + Sync> FromRequestParts<S> for BasicAuthFromRequest {
77 type Rejection = ApiErrorKind;
78
79 async fn from_request_parts(parts: &mut Parts, state: &S) -> ApiResult<Self> {
82 let checker = AUTH.get().unwrap();
83
84 let auth_result = AuthBasic::from_request_parts(parts, state).await;
85
86 tracing::debug!(?auth_result, "[AUTH]");
87
88 return match auth_result {
89 Ok(auth) => {
90 let AuthBasic((user, passw)) = auth;
91 let password = passw.unwrap_or_else(String::new);
92 if checker.verify(user.as_str(), password.as_str()) {
93 Ok(Self {
94 user,
95 _password: password.into(),
96 })
97 } else {
98 Err(ApiErrorKind::UserAuthenticationError(user))
99 }
100 }
101 Err(_) => {
102 let user = String::new();
103 if checker.verify("", "") {
104 return Ok(Self {
105 user,
106 _password: String::new().into(),
107 });
108 }
109 Err(ApiErrorKind::AuthenticationHeaderError)
110 }
111 };
112 }
113}
114
#[cfg(test)]
mod test {
    use super::*;

    use crate::testing::{basic_auth_header_value, init_test_environment, server_config};

    use anyhow::Result;
    use axum::{
        body::Body,
        http::{Method, Request, StatusCode},
        routing::get,
        Router,
    };
    use http_body_util::BodyExt;
    use rstest::{fixture, rstest};
    use tower::ServiceExt;

    /// Checker loaded from the `.htpasswd` file in the test fixtures.
    #[fixture]
    fn auth() -> Auth {
        let htpasswd_path = PathBuf::from("tests/fixtures/test_data/.htpasswd");
        Auth::from_file(false, &htpasswd_path).unwrap()
    }

    /// Drains a response body and decodes it as UTF-8.
    async fn body_text(body: Body) -> String {
        let collected = body.collect().await.unwrap();
        String::from_utf8(collected.to_bytes().to_vec()).unwrap()
    }

    #[rstest]
    fn test_auth_passes(auth: Auth) -> Result<()> {
        assert!(auth.verify("rustic", "rustic"));
        assert!(!auth.verify("rustic", "_rustic"));

        Ok(())
    }

    #[rstest]
    fn test_auth_from_file_passes(auth: Auth) {
        init_auth(auth).unwrap();

        let global = AUTH.get().unwrap();
        assert!(global.verify("rustic", "rustic"));
        assert!(!global.verify("rustic", "_rustic"));
    }

    /// Echoes the raw output of the `AuthBasic` extractor.
    async fn format_auth_basic(AuthBasic((id, password)): AuthBasic) -> String {
        format!("Got {} and {:?}", id, password)
    }

    /// Echoes the user accepted by the `BasicAuthFromRequest` extractor.
    async fn format_handler_from_auth_request(auth: BasicAuthFromRequest) -> String {
        format!("User = {}", auth.user)
    }

    #[tokio::test]
    async fn test_authentication_passes() {
        init_test_environment(server_config());

        // The plain axum-auth extractor hands back whatever the header carries.
        let basic_app = Router::new().route("/basic", get(format_auth_basic));
        let basic_request = Request::builder()
            .uri("/basic")
            .method(Method::GET)
            .header(
                "Authorization",
                basic_auth_header_value("My Username", Some("My Password")),
            )
            .body(Body::empty())
            .unwrap();

        let basic_response = basic_app.oneshot(basic_request).await.unwrap();
        assert_eq!(basic_response.status(), StatusCode::OK);
        assert_eq!(
            body_text(basic_response.into_body()).await,
            String::from("Got My Username and Some(\"My Password\")")
        );

        // Our extractor lets valid credentials through and exposes the user name.
        let server_app = Router::new().route("/rustic_server", get(format_handler_from_auth_request));
        let server_request = Request::builder()
            .uri("/rustic_server")
            .method(Method::GET)
            .header(
                "Authorization",
                basic_auth_header_value("rustic", Some("rustic")),
            )
            .body(Body::empty())
            .unwrap();

        let server_response = server_app.oneshot(server_request).await.unwrap();
        assert_eq!(server_response.status(), StatusCode::OK);
        assert_eq!(
            body_text(server_response.into_body()).await,
            String::from("User = rustic")
        );
    }

    #[tokio::test]
    async fn test_fail_authentication_passes() {
        init_test_environment(server_config());

        // Wrong password is rejected.
        let wrong_password_app =
            Router::new().route("/rustic_server", get(format_handler_from_auth_request));
        let wrong_password_request = Request::builder()
            .uri("/rustic_server")
            .method(Method::GET)
            .header(
                "Authorization",
                basic_auth_header_value("rustic", Some("_rustic")),
            )
            .body(Body::empty())
            .unwrap();

        let wrong_password_response = wrong_password_app
            .oneshot(wrong_password_request)
            .await
            .unwrap();
        assert_eq!(wrong_password_response.status(), StatusCode::FORBIDDEN);

        // A missing Authorization header is rejected as well.
        let no_header_app =
            Router::new().route("/rustic_server", get(format_handler_from_auth_request));
        let no_header_request = Request::builder()
            .uri("/rustic_server")
            .method(Method::GET)
            .body(Body::empty())
            .unwrap();

        let no_header_response = no_header_app.oneshot(no_header_request).await.unwrap();
        assert_eq!(no_header_response.status(), StatusCode::FORBIDDEN);
    }
}