spec_ai/spec_ai_api/api/
middleware.rs1use crate::spec_ai_api::api::auth::AuthService;
3use axum::{
4 extract::{Request, State},
5 http::{header, StatusCode},
6 middleware::Next,
7 response::{IntoResponse, Response},
8 Json,
9};
10use std::sync::Arc;
11
/// Identity attached to a request's extensions after its bearer token has been
/// accepted by [`auth_middleware`].
///
/// Handlers that run after the middleware can retrieve it from the request
/// extensions.
#[derive(Debug, Clone)]
pub struct AuthenticatedUser {
    /// Username returned by `AuthService::validate_token` for the presented token.
    pub username: String,
}
17
18pub async fn auth_middleware(
26 State(auth_service): State<Arc<AuthService>>,
27 mut request: Request,
28 next: Next,
29) -> Response {
30 if !auth_service.is_enabled() {
32 return next.run(request).await;
33 }
34
35 let auth_header = request
37 .headers()
38 .get(header::AUTHORIZATION)
39 .and_then(|h| h.to_str().ok());
40
41 let Some(auth_str) = auth_header else {
42 return unauthorized_response("Missing Authorization header");
43 };
44
45 let Some(token) = auth_str.strip_prefix("Bearer ") else {
47 return unauthorized_response(
48 "Invalid Authorization header format. Expected: Bearer <token>",
49 );
50 };
51
52 let Some(username) = auth_service.validate_token(token) else {
54 return unauthorized_response("Invalid or expired token");
55 };
56
57 request
59 .extensions_mut()
60 .insert(AuthenticatedUser { username });
61
62 next.run(request).await
63}
64
65fn unauthorized_response(message: &str) -> Response {
67 let body = serde_json::json!({
68 "error": message,
69 "code": "unauthorized"
70 });
71
72 (
73 StatusCode::UNAUTHORIZED,
74 [(header::CONTENT_TYPE, "application/json")],
75 Json(body),
76 )
77 .into_response()
78}
79
/// Optional static API-key check.
///
/// With `None`, the check is disabled and every key is accepted. With
/// `Some(key)`, only that exact key is accepted.
pub struct ApiKeyAuth {
    /// The expected key, or `None` when API-key authentication is disabled.
    api_key: Option<String>,
}

impl ApiKeyAuth {
    /// Creates a checker; pass `None` to disable API-key authentication.
    pub fn new(api_key: Option<String>) -> Self {
        Self { api_key }
    }

    /// Returns `true` when an expected API key has been configured.
    pub fn is_enabled(&self) -> bool {
        self.api_key.is_some()
    }

    /// Returns `true` if `key` exactly matches the configured API key. When no key
    /// is configured, it returns `true` unconditionally.
    ///
    /// The comparison does not stop at the first differing byte. A plain `==` would,
    /// letting response timing reveal how much of the secret a caller has guessed.
    pub fn validate(&self, key: &str) -> bool {
        match &self.api_key {
            Some(expected) => Self::keys_match(expected.as_bytes(), key.as_bytes()),
            // Authentication is disabled: accept any key.
            None => true,
        }
    }

    /// Compares two byte strings without short-circuiting on the first mismatch.
    ///
    /// Only the length check exits early, which reveals the key's length but not
    /// its contents. This is best-effort: nothing formally prevents the optimizer
    /// from introducing an early exit.
    fn keys_match(expected: &[u8], provided: &[u8]) -> bool {
        if expected.len() != provided.len() {
            return false;
        }
        expected
            .iter()
            .zip(provided)
            .fold(0u8, |diff, (a, b)| diff | (a ^ b))
            == 0
    }
}
103
#[cfg(test)]
mod tests {
    use super::*;

    /// With no key configured, auth is off and any key passes.
    #[test]
    fn test_api_key_auth_disabled() {
        let checker = ApiKeyAuth::new(None);
        assert!(!checker.is_enabled(), "no key configured, so auth must be disabled");
        assert!(checker.validate("any_key"), "disabled auth must accept any key");
    }

    /// With a key configured, only that key is accepted.
    #[test]
    fn test_api_key_auth_enabled() {
        let checker = ApiKeyAuth::new(Some(String::from("secret123")));
        assert!(checker.is_enabled());
        for (candidate, expected) in [("secret123", true), ("wrong_key", false)] {
            assert_eq!(checker.validate(candidate), expected, "validate({candidate:?})");
        }
    }

    /// The exact key is accepted; empty and wrong keys are rejected.
    #[test]
    fn test_api_key_validation() {
        let checker = ApiKeyAuth::new(Some(String::from("my-secret-key")));
        let cases = [("my-secret-key", true), ("", false), ("wrong", false)];
        for (candidate, expected) in cases {
            assert_eq!(checker.validate(candidate), expected, "validate({candidate:?})");
        }
    }
}