Skip to main content

spec_ai/spec_ai_api/api/
middleware.rs

//! API authentication and middleware.
2use crate::spec_ai_api::api::auth::AuthService;
3use axum::{
4    extract::{Request, State},
5    http::{header, StatusCode},
6    middleware::Next,
7    response::{IntoResponse, Response},
8    Json,
9};
10use std::sync::Arc;
11
/// Identity of the caller, stored in the request extensions by
/// [`auth_middleware`] after a bearer token has been validated.
///
/// Only present when authentication is enabled and succeeded; when auth is
/// disabled the middleware forwards the request without inserting it.
#[derive(Debug, Clone)]
pub struct AuthenticatedUser {
    /// Username returned by `AuthService::validate_token` for the presented token.
    pub username: String,
}
17
18/// Axum middleware function for bearer token authentication
19///
20/// This middleware:
21/// 1. Checks if auth is enabled in the AuthService
22/// 2. If disabled, allows all requests through
23/// 3. If enabled, validates the Bearer token from Authorization header
24/// 4. Adds AuthenticatedUser extension to request if valid
25pub async fn auth_middleware(
26    State(auth_service): State<Arc<AuthService>>,
27    mut request: Request,
28    next: Next,
29) -> Response {
30    // If auth is not enabled, allow all requests
31    if !auth_service.is_enabled() {
32        return next.run(request).await;
33    }
34
35    // Extract Authorization header
36    let auth_header = request
37        .headers()
38        .get(header::AUTHORIZATION)
39        .and_then(|h| h.to_str().ok());
40
41    let Some(auth_str) = auth_header else {
42        return unauthorized_response("Missing Authorization header");
43    };
44
45    // Must be Bearer token
46    let Some(token) = auth_str.strip_prefix("Bearer ") else {
47        return unauthorized_response(
48            "Invalid Authorization header format. Expected: Bearer <token>",
49        );
50    };
51
52    // Validate token
53    let Some(username) = auth_service.validate_token(token) else {
54        return unauthorized_response("Invalid or expired token");
55    };
56
57    // Add authenticated user to request extensions
58    request
59        .extensions_mut()
60        .insert(AuthenticatedUser { username });
61
62    next.run(request).await
63}
64
65/// Create an unauthorized response with JSON error body
66fn unauthorized_response(message: &str) -> Response {
67    let body = serde_json::json!({
68        "error": message,
69        "code": "unauthorized"
70    });
71
72    (
73        StatusCode::UNAUTHORIZED,
74        [(header::CONTENT_TYPE, "application/json")],
75        Json(body),
76    )
77        .into_response()
78}
79
/// Legacy API key authentication (kept for backward compatibility).
///
/// Constructed with `None`, authentication is disabled and every key is
/// accepted; constructed with `Some(key)`, only that exact key is accepted.
pub struct ApiKeyAuth {
    api_key: Option<String>,
}

impl ApiKeyAuth {
    /// Creates a checker for `api_key`; `None` disables API-key authentication.
    pub fn new(api_key: Option<String>) -> Self {
        Self { api_key }
    }

    /// Returns `true` when an API key has been configured.
    pub fn is_enabled(&self) -> bool {
        self.api_key.is_some()
    }

    /// Returns whether `key` matches the configured API key.
    ///
    /// If no key is configured, every key is accepted. The comparison examines
    /// every byte instead of stopping at the first mismatch, so the time taken
    /// is intended not to reveal how long a correct prefix the caller guessed.
    pub fn validate(&self, key: &str) -> bool {
        match &self.api_key {
            Some(expected) => constant_time_eq(expected.as_bytes(), key.as_bytes()),
            None => true, // No auth required if not configured
        }
    }
}

/// Compares two byte slices without short-circuiting on the first differing byte.
///
/// The lengths are compared up front, so timing can reveal whether the lengths
/// differ, but not where the contents diverge. This is best-effort: nothing
/// formally prevents the optimizer from introducing an early exit.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    // OR together the XOR of every byte pair; any difference leaves a bit set.
    let diff = a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y));
    diff == 0
}
103
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_api_key_auth_disabled() {
        let auth = ApiKeyAuth::new(None);
        assert!(!auth.is_enabled());
        assert!(auth.validate("any_key"));
    }

    #[test]
    fn test_api_key_auth_enabled() {
        let auth = ApiKeyAuth::new(Some("secret123".to_string()));
        assert!(auth.is_enabled());
        assert!(auth.validate("secret123"));
        assert!(!auth.validate("wrong_key"));
    }

    #[test]
    fn test_api_key_validation() {
        let auth = ApiKeyAuth::new(Some("my-secret-key".to_string()));

        assert!(auth.validate("my-secret-key"));
        assert!(!auth.validate(""));
        assert!(!auth.validate("wrong"));
    }

    #[test]
    fn test_api_key_rejects_near_misses() {
        let auth = ApiKeyAuth::new(Some("secret123".to_string()));

        // Keys that share a prefix, extend the real key, differ in one byte,
        // or differ only in case must all be rejected.
        assert!(!auth.validate("secret12"));
        assert!(!auth.validate("secret1234"));
        assert!(!auth.validate("secret124"));
        assert!(!auth.validate("SECRET123"));
    }
}