// stakpak_server/auth.rs

use axum::{
    extract::State,
    http::{
        Request, StatusCode,
        header::{AUTHORIZATION, WWW_AUTHENTICATE},
    },
    middleware::Next,
    response::{IntoResponse, Response},
};
use serde::Serialize;

/// Authentication settings consumed by the [`require_bearer`] middleware.
///
/// `Debug` is implemented by hand (instead of derived) so that formatting the config,
/// e.g. in logs, never reveals the secret token.
#[derive(Clone)]
pub struct AuthConfig {
    /// Token expected in `Authorization: Bearer <token>`; `None` means no check is made.
    pub auth_token: Option<String>,
    /// When `true`, every request is allowed regardless of `auth_token`.
    pub no_auth: bool,
}

impl std::fmt::Debug for AuthConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Report only whether a token is configured, never its value.
        f.debug_struct("AuthConfig")
            .field("auth_token", &self.auth_token.as_ref().map(|_| "<redacted>"))
            .field("no_auth", &self.no_auth)
            .finish()
    }
}

15impl AuthConfig {
16    pub fn disabled() -> Self {
17        Self {
18            auth_token: None,
19            no_auth: true,
20        }
21    }
22
23    pub fn token(auth_token: impl Into<String>) -> Self {
24        Self {
25            auth_token: Some(auth_token.into()),
26            no_auth: false,
27        }
28    }
29
30    fn should_bypass(&self) -> bool {
31        self.no_auth || self.auth_token.is_none()
32    }
33
34    fn is_authorized(&self, request: &Request<axum::body::Body>) -> bool {
35        if self.should_bypass() {
36            return true;
37        }
38
39        let Some(expected_token) = self.auth_token.as_deref() else {
40            return true;
41        };
42
43        request
44            .headers()
45            .get(AUTHORIZATION)
46            .and_then(|value| value.to_str().ok())
47            .and_then(|value| value.strip_prefix("Bearer "))
48            .is_some_and(|provided| provided == expected_token)
49    }
50}
51
52#[derive(Debug, Serialize)]
53struct AuthErrorBody {
54    error: String,
55    code: String,
56}
57
58pub async fn require_bearer(
59    State(config): State<AuthConfig>,
60    request: Request<axum::body::Body>,
61    next: Next,
62) -> Response {
63    if config.is_authorized(&request) {
64        return next.run(request).await;
65    }
66
67    let body = AuthErrorBody {
68        error: "Unauthorized".to_string(),
69        code: "unauthorized".to_string(),
70    };
71
72    (StatusCode::UNAUTHORIZED, axum::Json(body)).into_response()
73}