1use crate::types::{derive_account_id, JmapError, JmapErrorType, Principal};
21use axum::{
22 extract::{Request, State},
23 http::{header, HeaderMap, StatusCode},
24 middleware::Next,
25 response::{IntoResponse, Response},
26 Json,
27};
28use base64::{engine::general_purpose, Engine as _};
29use rusmes_auth::AuthBackend;
30use rusmes_proto::Username;
31use std::sync::Arc;
32
/// Shared, type-erased authentication backend.
///
/// Used as the axum router state consumed by [`require_auth`] through
/// `State<SharedAuth>`.
pub type SharedAuth = Arc<dyn AuthBackend>;
37
/// Credentials parsed from an HTTP `Authorization` header.
///
/// `Debug` is implemented by hand rather than derived. This ensures passwords
/// and bearer tokens never end up in logs, `{:?}` output, or panic messages
/// (for example from a failed `assert_eq!`). Only the Basic username is shown;
/// secrets are replaced by a fixed placeholder.
#[derive(Clone, PartialEq, Eq)]
pub enum Credentials {
    /// HTTP Basic credentials: a `user:password` pair, already decoded from base64.
    Basic { username: String, password: String },
    /// An opaque bearer token, passed through as received.
    Bearer { token: String },
}

impl std::fmt::Debug for Credentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Placeholder printed instead of any secret value.
        const REDACTED: &str = "<redacted>";
        match self {
            Credentials::Basic { username, .. } => f
                .debug_struct("Basic")
                .field("username", username)
                .field("password", &REDACTED)
                .finish(),
            Credentials::Bearer { .. } => f
                .debug_struct("Bearer")
                .field("token", &REDACTED)
                .finish(),
        }
    }
}
48
49pub fn extract_credentials(headers: &HeaderMap) -> Option<Credentials> {
55 let value = headers.get(header::AUTHORIZATION)?.to_str().ok()?;
56 let trimmed = value.trim();
57
58 if let Some(rest) = strip_scheme(trimmed, "Basic") {
59 let decoded_bytes = general_purpose::STANDARD.decode(rest).ok()?;
60 let decoded = String::from_utf8(decoded_bytes).ok()?;
61 let mut parts = decoded.splitn(2, ':');
62 let username = parts.next()?.to_string();
63 let password = parts.next()?.to_string();
64 if username.is_empty() {
65 return None;
66 }
67 return Some(Credentials::Basic { username, password });
68 }
69
70 if let Some(rest) = strip_scheme(trimmed, "Bearer") {
71 let token = rest.trim().to_string();
72 if token.is_empty() {
73 return None;
74 }
75 return Some(Credentials::Bearer { token });
76 }
77
78 None
79}
80
/// Strips a case-insensitive HTTP authentication scheme from the start of an
/// `Authorization` header value.
///
/// Returns the credentials part, with leading whitespace removed, when
/// `header_value`:
///
/// * starts with `scheme` (ASCII case-insensitive),
/// * is followed by at least one space or tab, and
/// * has non-empty content after that whitespace.
///
/// Otherwise returns `None`. RFC 7235 §2.1 requires whitespace between the
/// scheme and its credentials, so `"BasicYWxp..."` or `"Bearerxyz"` are not
/// treated as `Basic` / `Bearer`. Never panics, even if `scheme.len()` does
/// not land on a UTF-8 character boundary of `header_value`.
fn strip_scheme<'a>(header_value: &'a str, scheme: &str) -> Option<&'a str> {
    // `get` returns `None` (instead of panicking like `split_at`) when the
    // value is too short or the cut would split a multi-byte character.
    let prefix = header_value.get(..scheme.len())?;
    if !prefix.eq_ignore_ascii_case(scheme) {
        return None;
    }
    let rest = header_value.get(scheme.len()..)?;
    // Require a separator so a longer scheme name sharing this prefix (or a
    // token glued to the scheme) is not mistaken for this scheme.
    if !rest.starts_with(|c: char| c == ' ' || c == '\t') {
        return None;
    }
    let rest = rest.trim_start();
    if rest.is_empty() {
        None
    } else {
        Some(rest)
    }
}
96
97pub async fn authenticate(
105 auth: &dyn AuthBackend,
106 creds: &Credentials,
107) -> Result<Principal, AuthError> {
108 match creds {
109 Credentials::Basic { username, password } => {
110 let user = Username::new(username.clone()).map_err(|_| AuthError::Unauthorized)?;
111 let ok = auth
112 .authenticate(&user, password)
113 .await
114 .map_err(|err| AuthError::Backend(err.to_string()))?;
115 if !ok {
116 return Err(AuthError::Unauthorized);
117 }
118 Ok(Principal {
119 username: username.clone(),
120 account_id: derive_account_id(username),
121 scopes: Vec::new(),
122 })
123 }
124 Credentials::Bearer { token } => {
125 let username = auth
126 .verify_bearer_token(token)
127 .await
128 .map_err(|_| AuthError::Unauthorized)?;
129 let username_str = username.to_string();
130 Ok(Principal {
131 account_id: derive_account_id(&username_str),
132 username: username_str,
133 scopes: Vec::new(),
134 })
135 }
136 }
137}
138
/// Reasons a request can fail authentication.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`], so it can be
/// logged with `{}`, boxed as `dyn Error`, or wrapped by other error types.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AuthError {
    /// The credentials were missing, malformed, or rejected, or a bearer token
    /// could not be verified.
    Unauthorized,
    /// The authentication backend itself failed; carries the backend's error text.
    Backend(String),
}

impl std::fmt::Display for AuthError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            AuthError::Unauthorized => f.write_str("authentication required"),
            AuthError::Backend(msg) => write!(f, "authentication backend error: {}", msg),
        }
    }
}

impl std::error::Error for AuthError {}
148
149impl AuthError {
150 fn into_response_body(self) -> Response {
151 let detail = match self {
152 AuthError::Unauthorized => "Authentication required".to_string(),
153 AuthError::Backend(err) => {
154 tracing::warn!("JMAP auth backend error: {}", err);
155 "Authentication backend error".to_string()
156 }
157 };
158 let body = JmapError::new(JmapErrorType::ServerFail)
159 .with_status(401)
160 .with_detail(detail);
161 let mut resp = (StatusCode::UNAUTHORIZED, Json(body)).into_response();
162 if let Ok(value) = header::HeaderValue::from_str("Basic realm=\"jmap\"") {
164 resp.headers_mut().insert(header::WWW_AUTHENTICATE, value);
165 }
166 resp
167 }
168}
169
170pub async fn require_auth(
174 State(auth): State<SharedAuth>,
175 mut request: Request,
176 next: Next,
177) -> Response {
178 let creds = match extract_credentials(request.headers()) {
179 Some(c) => c,
180 None => return AuthError::Unauthorized.into_response_body(),
181 };
182 let principal = match authenticate(auth.as_ref(), &creds).await {
183 Ok(p) => p,
184 Err(err) => return err.into_response_body(),
185 };
186 request.extensions_mut().insert(principal);
187 next.run(request).await
188}
189
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use axum::http::HeaderValue;

    /// In-memory backend that accepts exactly one account: `alice` / `hunter2`.
    ///
    /// It does not override `verify_bearer_token`, so bearer authentication
    /// uses the `AuthBackend` trait's default implementation. That default is
    /// not visible in this file; the bearer test below assumes it returns an
    /// error for every token.
    struct TestBackend;

    #[async_trait]
    impl AuthBackend for TestBackend {
        // The only password check: "alice" with "hunter2" succeeds, anything else is `Ok(false)`.
        async fn authenticate(&self, username: &Username, password: &str) -> anyhow::Result<bool> {
            Ok(username.as_str() == "alice" && password == "hunter2")
        }
        // The remaining methods are inert stubs required by the trait.
        async fn verify_identity(&self, _username: &Username) -> anyhow::Result<bool> {
            Ok(true)
        }
        async fn list_users(&self) -> anyhow::Result<Vec<Username>> {
            Ok(vec![])
        }
        async fn create_user(&self, _u: &Username, _p: &str) -> anyhow::Result<()> {
            Ok(())
        }
        async fn delete_user(&self, _u: &Username) -> anyhow::Result<()> {
            Ok(())
        }
        async fn change_password(&self, _u: &Username, _p: &str) -> anyhow::Result<()> {
            Ok(())
        }
    }

    /// Builds a `HeaderMap` whose `Authorization` header is `value`.
    /// If `value` is not a valid header value, the header is simply omitted.
    fn header_with_auth(value: &str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        if let Ok(v) = HeaderValue::from_str(value) {
            headers.insert(header::AUTHORIZATION, v);
        }
        headers
    }

    #[test]
    fn test_extract_basic_ok() {
        // "YWxpY2U6aHVudGVyMg==" is base64 for "alice:hunter2".
        let headers = header_with_auth("Basic YWxpY2U6aHVudGVyMg==");
        let creds = extract_credentials(&headers).expect("creds parse");
        assert_eq!(
            creds,
            Credentials::Basic {
                username: "alice".to_string(),
                password: "hunter2".to_string()
            }
        );
    }

    #[test]
    fn test_extract_basic_case_insensitive_scheme() {
        // The scheme name is matched case-insensitively.
        let headers = header_with_auth("basic YWxpY2U6aHVudGVyMg==");
        assert!(extract_credentials(&headers).is_some());
    }

    #[test]
    fn test_extract_bearer_ok() {
        let headers = header_with_auth("Bearer abc.def.ghi");
        let creds = extract_credentials(&headers).expect("creds parse");
        assert_eq!(
            creds,
            Credentials::Bearer {
                token: "abc.def.ghi".to_string()
            }
        );
    }

    #[test]
    fn test_extract_no_header() {
        let headers = HeaderMap::new();
        assert!(extract_credentials(&headers).is_none());
    }

    #[test]
    fn test_extract_unknown_scheme() {
        // Only Basic and Bearer are recognised.
        let headers = header_with_auth("Digest something");
        assert!(extract_credentials(&headers).is_none());
    }

    #[test]
    fn test_extract_basic_empty_username_rejected() {
        // "OnB3ZA==" is base64 for ":pwd": empty username, non-empty password.
        let headers = header_with_auth("Basic OnB3ZA==");
        assert!(extract_credentials(&headers).is_none());
    }

    #[test]
    fn test_extract_basic_no_colon_rejected() {
        // "YWxpY2VodW50ZXIy" is base64 for "alicehunter2": no ':' separator.
        let headers = header_with_auth("Basic YWxpY2VodW50ZXIy");
        assert!(extract_credentials(&headers).is_none());
    }

    #[tokio::test]
    async fn test_authenticate_basic_ok() {
        let backend = TestBackend;
        let creds = Credentials::Basic {
            username: "alice".to_string(),
            password: "hunter2".to_string(),
        };
        let principal = authenticate(&backend, &creds).await.expect("auth ok");
        assert_eq!(principal.username, "alice");
        // Pins the current output format of `derive_account_id`.
        assert_eq!(principal.account_id, "account-alice");
    }

    #[tokio::test]
    async fn test_authenticate_basic_bad_password() {
        // A wrong password is `Ok(false)` from the backend, which maps to Unauthorized.
        let backend = TestBackend;
        let creds = Credentials::Basic {
            username: "alice".to_string(),
            password: "wrong".to_string(),
        };
        let err = authenticate(&backend, &creds)
            .await
            .expect_err("should fail");
        assert_eq!(err, AuthError::Unauthorized);
    }

    #[tokio::test]
    async fn test_authenticate_bearer_backend_without_override_rejected() {
        // TestBackend does not override `verify_bearer_token`; this relies on the
        // trait default rejecting the token (assumption, not visible here).
        let backend = TestBackend;
        let creds = Credentials::Bearer {
            token: "anything".to_string(),
        };
        let err = authenticate(&backend, &creds)
            .await
            .expect_err("bearer 401");
        assert_eq!(err, AuthError::Unauthorized);
    }

    #[tokio::test]
    async fn test_authenticate_basic_with_email_username() {
        // An email-style username that TestBackend does not know must be rejected
        // as Unauthorized. This should hold whether `Username::new` rejects the
        // form or the backend returns `Ok(false)`.
        let backend = TestBackend;
        let creds = Credentials::Basic {
            username: "bob@example.com".to_string(),
            password: "hunter2".to_string(),
        };
        let err = authenticate(&backend, &creds).await.expect_err("rejected");
        assert_eq!(err, AuthError::Unauthorized);
    }
}
338}