vex_api/
middleware.rs

1//! Tower middleware for VEX API
2
3use axum::{
4    extract::{Request, State},
5    http::header,
6    middleware::Next,
7    response::Response,
8};
9use std::time::Instant;
10
11use crate::auth::{Claims, JwtAuth};
12use crate::error::ApiError;
13use crate::state::AppState;
// The rate limiter and metrics are reached through `AppState` accessors (`state.rate_limiter()`, `state.metrics()`), so `vex_llm` is not imported directly here.
15
16/// Authentication middleware
17pub async fn auth_middleware(
18    State(state): State<AppState>,
19    mut request: Request,
20    next: Next,
21) -> Result<Response, ApiError> {
22    // Skip auth for health check and public endpoints
23    let path = request.uri().path();
24    if path == "/health" || path.starts_with("/public/") {
25        return Ok(next.run(request).await);
26    }
27
28    // Extract token from header
29    let auth_header = request
30        .headers()
31        .get(header::AUTHORIZATION)
32        .and_then(|v| v.to_str().ok())
33        .ok_or_else(|| ApiError::Unauthorized("Missing Authorization header".to_string()))?;
34
35    let token = JwtAuth::extract_from_header(auth_header)?;
36    let claims = state.jwt_auth().decode(token)?;
37
38    // Insert claims into request extensions for handlers
39    request.extensions_mut().insert(claims);
40
41    Ok(next.run(request).await)
42}
43
44/// Rate limiting middleware
45pub async fn rate_limit_middleware(
46    State(state): State<AppState>,
47    request: Request,
48    next: Next,
49) -> Result<Response, ApiError> {
50    // Get client identifier (from claims or IP)
51    let client_id = request
52        .extensions()
53        .get::<Claims>()
54        .map(|c| c.sub.clone())
55        .unwrap_or_else(|| "anonymous".to_string());
56
57    // Check rate limit
58    state
59        .rate_limiter()
60        .try_acquire(&client_id)
61        .await
62        .map_err(|_| ApiError::RateLimited)?;
63
64    Ok(next.run(request).await)
65}
66
67/// Request tracing middleware
68pub async fn tracing_middleware(
69    State(state): State<AppState>,
70    request: Request,
71    next: Next,
72) -> Response {
73    let start = Instant::now();
74    let method = request.method().clone();
75    let uri = request.uri().clone();
76    let path = uri.path().to_string();
77
78    // Create span for this request
79    let span = tracing::info_span!(
80        "http_request",
81        method = %method,
82        path = %path,
83        status = tracing::field::Empty,
84        latency_ms = tracing::field::Empty,
85    );
86
87    let response = {
88        let _enter = span.enter();
89        next.run(request).await
90    };
91
92    let latency = start.elapsed();
93    let status = response.status();
94
95    // Record metrics
96    state.metrics().record_llm_call(0, !status.is_success());
97
98    // Log request
99    tracing::info!(
100        method = %method,
101        path = %path,
102        status = %status.as_u16(),
103        latency_ms = %latency.as_millis(),
104        "Request completed"
105    );
106
107    response
108}
109
110/// Request ID middleware
111pub async fn request_id_middleware(mut request: Request, next: Next) -> Response {
112    let request_id = uuid::Uuid::new_v4().to_string();
113
114    // Add to request extensions
115    request
116        .extensions_mut()
117        .insert(RequestId(request_id.clone()));
118
119    let mut response = next.run(request).await;
120
121    // Add to response headers
122    response
123        .headers_mut()
124        .insert("X-Request-ID", request_id.parse().unwrap());
125
126    response
127}
128
/// Per-request identifier inserted into request extensions by
/// `request_id_middleware`.
///
/// Wraps the ID string; equality and hashing compare that string, so the ID
/// can be used as a map/set key (e.g. for correlating log records).
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct RequestId(pub String);
132
133/// CORS configuration helper
134/// Reads allowed origins from VEX_CORS_ORIGINS env var (comma-separated)
135/// Falls back to restrictive default if not set
136pub fn cors_layer() -> tower_http::cors::CorsLayer {
137    use tower_http::cors::{AllowOrigin, CorsLayer};
138
139    let origins = std::env::var("VEX_CORS_ORIGINS").ok();
140
141    let allow_origin = match origins {
142        Some(origins_str) if !origins_str.is_empty() => {
143            let origins: Vec<axum::http::HeaderValue> = origins_str
144                .split(',')
145                .filter_map(|s| s.trim().parse().ok())
146                .collect();
147            if origins.is_empty() {
148                tracing::warn!("VEX_CORS_ORIGINS is set but contains no valid origins, using restrictive default");
149                AllowOrigin::exact("https://localhost".parse().unwrap())
150            } else {
151                tracing::info!("CORS configured for {} origin(s)", origins.len());
152                AllowOrigin::list(origins)
153            }
154        }
155        _ => {
156            // No CORS_ORIGINS set - use restrictive default for security
157            tracing::warn!("VEX_CORS_ORIGINS not set, using restrictive CORS (localhost only)");
158            AllowOrigin::exact("https://localhost".parse().unwrap())
159        }
160    };
161
162    CorsLayer::new()
163        .allow_origin(allow_origin)
164        .allow_methods([
165            axum::http::Method::GET,
166            axum::http::Method::POST,
167            axum::http::Method::PUT,
168            axum::http::Method::DELETE,
169            axum::http::Method::OPTIONS,
170        ])
171        .allow_headers([header::AUTHORIZATION, header::CONTENT_TYPE, header::ACCEPT])
172        .max_age(std::time::Duration::from_secs(3600))
173}
174
175/// Timeout layer helper
176#[allow(deprecated)]
177pub fn timeout_layer(duration: std::time::Duration) -> tower_http::timeout::TimeoutLayer {
178    tower_http::timeout::TimeoutLayer::new(duration)
179}
180
181/// Request body size limit
182pub fn body_limit_layer(limit: usize) -> tower_http::limit::RequestBodyLimitLayer {
183    tower_http::limit::RequestBodyLimitLayer::new(limit)
184}
185
186/// Security headers middleware
187/// Adds standard security headers to all responses
188pub async fn security_headers_middleware(request: Request, next: Next) -> Response {
189    let mut response = next.run(request).await;
190
191    let headers = response.headers_mut();
192
193    // Prevent MIME type sniffing
194    headers.insert("X-Content-Type-Options", "nosniff".parse().unwrap());
195
196    // Prevent clickjacking
197    headers.insert("X-Frame-Options", "DENY".parse().unwrap());
198
199    // XSS protection (legacy, but still useful)
200    headers.insert("X-XSS-Protection", "1; mode=block".parse().unwrap());
201
202    // Content Security Policy
203    headers.insert(
204        "Content-Security-Policy",
205        "default-src 'self'; frame-ancestors 'none'"
206            .parse()
207            .unwrap(),
208    );
209
210    // HSTS - Enable in production by setting VEX_ENABLE_HSTS=1
211    if std::env::var("VEX_ENABLE_HSTS").is_ok() {
212        headers.insert(
213            "Strict-Transport-Security",
214            "max-age=31536000; includeSubDomains".parse().unwrap(),
215        );
216    }
217
218    // Referrer policy
219    headers.insert(
220        "Referrer-Policy",
221        "strict-origin-when-cross-origin".parse().unwrap(),
222    );
223
224    // Permissions policy
225    headers.insert(
226        "Permissions-Policy",
227        "geolocation=(), microphone=(), camera=()".parse().unwrap(),
228    );
229
230    response
231}
232
#[cfg(test)]
mod tests {
    use super::*;

    /// Generated request IDs must differ between requests.
    #[test]
    fn test_request_id() {
        let id1 = uuid::Uuid::new_v4().to_string();
        let id2 = uuid::Uuid::new_v4().to_string();
        assert_ne!(id1, id2);
    }

    /// `request_id_middleware` converts the UUID string into a header value
    /// and panics if that fails; pin that the conversion always succeeds.
    #[test]
    fn request_id_is_valid_header_value() {
        let id = uuid::Uuid::new_v4().to_string();
        assert!(id.parse::<axum::http::HeaderValue>().is_ok());
    }

    /// The extension type must be cloneable (axum requires `Clone` for
    /// extensions) and must keep the wrapped ID intact.
    #[test]
    fn request_id_extension_round_trips() {
        let original = RequestId("abc".to_string());
        let copy = original.clone();
        assert_eq!(copy.0, original.0);
        assert_eq!(copy.0, "abc");
    }
}