1use axum::{
4 extract::{Request, State},
5 http::header,
6 middleware::Next,
7 response::Response,
8};
9use std::time::Instant;
10
11use crate::auth::{Claims, JwtAuth};
12use crate::error::ApiError;
13use crate::state::AppState;
14pub async fn auth_middleware(
18 State(state): State<AppState>,
19 mut request: Request,
20 next: Next,
21) -> Result<Response, ApiError> {
22 let path = request.uri().path();
24 if path == "/health" || path.starts_with("/public/") {
25 return Ok(next.run(request).await);
26 }
27
28 let auth_header = request
30 .headers()
31 .get(header::AUTHORIZATION)
32 .and_then(|v| v.to_str().ok())
33 .ok_or_else(|| ApiError::Unauthorized("Missing Authorization header".to_string()))?;
34
35 let token = JwtAuth::extract_from_header(auth_header)?;
36 let claims = state.jwt_auth().decode(token)?;
37
38 request.extensions_mut().insert(claims);
40
41 Ok(next.run(request).await)
42}
43
44pub async fn rate_limit_middleware(
46 State(state): State<AppState>,
47 request: Request,
48 next: Next,
49) -> Result<Response, ApiError> {
50 let client_id = request
52 .extensions()
53 .get::<Claims>()
54 .map(|c| c.sub.clone())
55 .unwrap_or_else(|| "anonymous".to_string());
56
57 state
59 .rate_limiter()
60 .try_acquire(&client_id)
61 .await
62 .map_err(|_| ApiError::RateLimited)?;
63
64 Ok(next.run(request).await)
65}
66
67pub async fn tracing_middleware(
69 State(state): State<AppState>,
70 request: Request,
71 next: Next,
72) -> Response {
73 let start = Instant::now();
74 let method = request.method().clone();
75 let uri = request.uri().clone();
76 let path = uri.path().to_string();
77
78 let span = tracing::info_span!(
80 "http_request",
81 method = %method,
82 path = %path,
83 status = tracing::field::Empty,
84 latency_ms = tracing::field::Empty,
85 );
86
87 let response = {
88 let _enter = span.enter();
89 next.run(request).await
90 };
91
92 let latency = start.elapsed();
93 let status = response.status();
94
95 state.metrics().record_llm_call(0, !status.is_success());
97
98 tracing::info!(
100 method = %method,
101 path = %path,
102 status = %status.as_u16(),
103 latency_ms = %latency.as_millis(),
104 "Request completed"
105 );
106
107 response
108}
109
110pub async fn request_id_middleware(mut request: Request, next: Next) -> Response {
112 let request_id = uuid::Uuid::new_v4().to_string();
113
114 request
116 .extensions_mut()
117 .insert(RequestId(request_id.clone()));
118
119 let mut response = next.run(request).await;
120
121 response
123 .headers_mut()
124 .insert("X-Request-ID", request_id.parse().unwrap());
125
126 response
127}
128
/// Per-request identifier that `request_id_middleware` inserts into the request
/// extensions.
///
/// Equality and hashing are derived so ids can be compared or used as map keys.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct RequestId(pub String);
132
133pub fn cors_layer() -> tower_http::cors::CorsLayer {
137 use tower_http::cors::{AllowOrigin, CorsLayer};
138
139 let origins = std::env::var("VEX_CORS_ORIGINS").ok();
140
141 let allow_origin = match origins {
142 Some(origins_str) if !origins_str.is_empty() => {
143 let origins: Vec<axum::http::HeaderValue> = origins_str
144 .split(',')
145 .filter_map(|s| s.trim().parse().ok())
146 .collect();
147 if origins.is_empty() {
148 tracing::warn!("VEX_CORS_ORIGINS is set but contains no valid origins, using restrictive default");
149 AllowOrigin::exact("https://localhost".parse().unwrap())
150 } else {
151 tracing::info!("CORS configured for {} origin(s)", origins.len());
152 AllowOrigin::list(origins)
153 }
154 }
155 _ => {
156 tracing::warn!("VEX_CORS_ORIGINS not set, using restrictive CORS (localhost only)");
158 AllowOrigin::exact("https://localhost".parse().unwrap())
159 }
160 };
161
162 CorsLayer::new()
163 .allow_origin(allow_origin)
164 .allow_methods([
165 axum::http::Method::GET,
166 axum::http::Method::POST,
167 axum::http::Method::PUT,
168 axum::http::Method::DELETE,
169 axum::http::Method::OPTIONS,
170 ])
171 .allow_headers([header::AUTHORIZATION, header::CONTENT_TYPE, header::ACCEPT])
172 .max_age(std::time::Duration::from_secs(3600))
173}
174
175#[allow(deprecated)]
177pub fn timeout_layer(duration: std::time::Duration) -> tower_http::timeout::TimeoutLayer {
178 tower_http::timeout::TimeoutLayer::new(duration)
179}
180
181pub fn body_limit_layer(limit: usize) -> tower_http::limit::RequestBodyLimitLayer {
183 tower_http::limit::RequestBodyLimitLayer::new(limit)
184}
185
186pub async fn security_headers_middleware(request: Request, next: Next) -> Response {
189 let mut response = next.run(request).await;
190
191 let headers = response.headers_mut();
192
193 headers.insert("X-Content-Type-Options", "nosniff".parse().unwrap());
195
196 headers.insert("X-Frame-Options", "DENY".parse().unwrap());
198
199 headers.insert("X-XSS-Protection", "1; mode=block".parse().unwrap());
201
202 headers.insert(
204 "Content-Security-Policy",
205 "default-src 'self'; frame-ancestors 'none'"
206 .parse()
207 .unwrap(),
208 );
209
210 if std::env::var("VEX_ENABLE_HSTS").is_ok() {
212 headers.insert(
213 "Strict-Transport-Security",
214 "max-age=31536000; includeSubDomains".parse().unwrap(),
215 );
216 }
217
218 headers.insert(
220 "Referrer-Policy",
221 "strict-origin-when-cross-origin".parse().unwrap(),
222 );
223
224 headers.insert(
226 "Permissions-Policy",
227 "geolocation=(), microphone=(), camera=()".parse().unwrap(),
228 );
229
230 response
231}
232
#[cfg(test)]
mod tests {
    use super::RequestId;

    #[test]
    fn test_request_id() {
        let id1 = uuid::Uuid::new_v4().to_string();
        let id2 = uuid::Uuid::new_v4().to_string();
        assert_ne!(id1, id2);
    }

    /// `request_id_middleware` unwraps the conversion of a generated id into a
    /// header value. This test pins the invariant that makes that unwrap safe.
    #[test]
    fn generated_request_id_is_valid_header_value() {
        let id = uuid::Uuid::new_v4().to_string();
        // Hyphenated UUID text form: 32 hex digits plus 4 hyphens.
        assert_eq!(id.len(), 36);
        assert!(id.parse::<axum::http::HeaderValue>().is_ok());
    }

    #[test]
    fn request_id_clone_preserves_value() {
        let original = RequestId("abc".to_string());
        let copy = original.clone();
        assert_eq!(copy.0, original.0);
        assert_eq!(copy.0, "abc");
    }
}