1use axum::{
2 extract::State,
3 http::{Request, StatusCode, header::AUTHORIZATION},
4 middleware::Next,
5 response::{IntoResponse, Response},
6};
7use serde::Serialize;
8
/// Settings consumed by the bearer-token middleware in this module.
///
/// Authentication is skipped when `no_auth` is `true` **or** when `auth_token`
/// is `None`; only when a token is configured and `no_auth` is `false` are
/// requests checked.
#[derive(Clone)]
pub struct AuthConfig {
    /// Secret that clients must present as `Authorization: Bearer <token>`.
    pub auth_token: Option<String>,
    /// When `true`, the bearer-token check is turned off entirely.
    pub no_auth: bool,
}

/// Hand-written instead of derived so the secret token never ends up in logs,
/// `dbg!` output or panic messages; only whether a token is set is shown.
impl std::fmt::Debug for AuthConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AuthConfig")
            .field(
                "auth_token",
                &self.auth_token.as_ref().map(|_| "<redacted>"),
            )
            .field("no_auth", &self.no_auth)
            .finish()
    }
}
14
15impl AuthConfig {
16 pub fn disabled() -> Self {
17 Self {
18 auth_token: None,
19 no_auth: true,
20 }
21 }
22
23 pub fn token(auth_token: impl Into<String>) -> Self {
24 Self {
25 auth_token: Some(auth_token.into()),
26 no_auth: false,
27 }
28 }
29
30 fn should_bypass(&self) -> bool {
31 self.no_auth || self.auth_token.is_none()
32 }
33
34 fn is_authorized(&self, request: &Request<axum::body::Body>) -> bool {
35 if self.should_bypass() {
36 return true;
37 }
38
39 let Some(expected_token) = self.auth_token.as_deref() else {
40 return true;
41 };
42
43 request
44 .headers()
45 .get(AUTHORIZATION)
46 .and_then(|value| value.to_str().ok())
47 .and_then(|value| value.strip_prefix("Bearer "))
48 .is_some_and(|provided| provided == expected_token)
49 }
50}
51
52#[derive(Debug, Serialize)]
53struct AuthErrorBody {
54 error: String,
55 code: String,
56}
57
58pub async fn require_bearer(
59 State(config): State<AuthConfig>,
60 request: Request<axum::body::Body>,
61 next: Next,
62) -> Response {
63 if config.is_authorized(&request) {
64 return next.run(request).await;
65 }
66
67 let body = AuthErrorBody {
68 error: "Unauthorized".to_string(),
69 code: "unauthorized".to_string(),
70 };
71
72 (StatusCode::UNAUTHORIZED, axum::Json(body)).into_response()
73}