vex_api/
middleware.rs

1//! Tower middleware for VEX API
2
3use axum::{
4    extract::{Request, State},
5    http::header,
6    middleware::Next,
7    response::Response,
8};
9use std::time::Instant;
10
11use crate::auth::{Claims, JwtAuth};
12use crate::error::ApiError;
13use crate::state::AppState;
// `RateLimiter` and `Metrics` (from `vex_llm`) are reached through `AppState` accessors
// (`rate_limiter()`, `metrics()`), so they are not imported directly here.
15
16/// Authentication middleware
17pub async fn auth_middleware(
18    State(state): State<AppState>,
19    mut request: Request,
20    next: Next,
21) -> Result<Response, ApiError> {
22    // Skip auth for health check and public endpoints
23    let path = request.uri().path();
24    if path == "/health" || path.starts_with("/public/") {
25        return Ok(next.run(request).await);
26    }
27
28    // Extract token from header
29    let auth_header = request
30        .headers()
31        .get(header::AUTHORIZATION)
32        .and_then(|v| v.to_str().ok())
33        .ok_or_else(|| ApiError::Unauthorized("Missing Authorization header".to_string()))?;
34
35    let token = JwtAuth::extract_from_header(auth_header)?;
36    let claims = state.jwt_auth().decode(token)?;
37
38    // Insert claims into request extensions for handlers
39    request.extensions_mut().insert(claims);
40
41    Ok(next.run(request).await)
42}
43
44/// Rate limiting middleware
45pub async fn rate_limit_middleware(
46    State(state): State<AppState>,
47    request: Request,
48    next: Next,
49) -> Result<Response, ApiError> {
50    // Get client identifier (from claims or IP)
51    let client_id = request
52        .extensions()
53        .get::<Claims>()
54        .map(|c| c.sub.clone())
55        .unwrap_or_else(|| "anonymous".to_string());
56
57    // Check rate limit
58    state
59        .rate_limiter()
60        .try_acquire(&client_id)
61        .await
62        .map_err(|_| ApiError::RateLimited)?;
63
64    Ok(next.run(request).await)
65}
66
67/// Request tracing middleware
68pub async fn tracing_middleware(
69    State(state): State<AppState>,
70    request: Request,
71    next: Next,
72) -> Response {
73    let start = Instant::now();
74    let method = request.method().clone();
75    let uri = request.uri().clone();
76    let path = uri.path().to_string();
77
78    // Create span for this request
79    let span = tracing::info_span!(
80        "http_request",
81        method = %method,
82        path = %path,
83        status = tracing::field::Empty,
84        latency_ms = tracing::field::Empty,
85    );
86
87    let response = {
88        let _enter = span.enter();
89        next.run(request).await
90    };
91
92    let latency = start.elapsed();
93    let status = response.status();
94
95    // Record metrics
96    state.metrics().record_llm_call(0, !status.is_success());
97
98    // Log request
99    tracing::info!(
100        method = %method,
101        path = %path,
102        status = %status.as_u16(),
103        latency_ms = %latency.as_millis(),
104        "Request completed"
105    );
106
107    response
108}
109
110/// Request ID middleware
111pub async fn request_id_middleware(mut request: Request, next: Next) -> Response {
112    let request_id = uuid::Uuid::new_v4().to_string();
113
114    // Add to request extensions
115    request
116        .extensions_mut()
117        .insert(RequestId(request_id.clone()));
118
119    let mut response = next.run(request).await;
120
121    // Add to response headers
122    response
123        .headers_mut()
124        .insert("X-Request-ID", request_id.parse().unwrap());
125
126    response
127}
128
/// Per-request identifier.
///
/// `request_id_middleware` inserts it into the request extensions and echoes it in the
/// `X-Request-ID` response header. Handlers can read it with
/// `request.extensions().get::<RequestId>()`.
///
/// The equality and hashing derives let it be compared, or used as a map/set key, when
/// correlating requests.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct RequestId(pub String);
132
133/// CORS configuration helper
134pub fn cors_layer() -> tower_http::cors::CorsLayer {
135    tower_http::cors::CorsLayer::new()
136        .allow_origin(tower_http::cors::Any)
137        .allow_methods([
138            axum::http::Method::GET,
139            axum::http::Method::POST,
140            axum::http::Method::PUT,
141            axum::http::Method::DELETE,
142            axum::http::Method::OPTIONS,
143        ])
144        .allow_headers([header::AUTHORIZATION, header::CONTENT_TYPE, header::ACCEPT])
145        .max_age(std::time::Duration::from_secs(3600))
146}
147
148/// Timeout layer helper
149#[allow(deprecated)]
150pub fn timeout_layer(duration: std::time::Duration) -> tower_http::timeout::TimeoutLayer {
151    tower_http::timeout::TimeoutLayer::new(duration)
152}
153
154/// Request body size limit
155pub fn body_limit_layer(limit: usize) -> tower_http::limit::RequestBodyLimitLayer {
156    tower_http::limit::RequestBodyLimitLayer::new(limit)
157}
158
159/// Security headers middleware
160/// Adds standard security headers to all responses
161pub async fn security_headers_middleware(request: Request, next: Next) -> Response {
162    let mut response = next.run(request).await;
163
164    let headers = response.headers_mut();
165
166    // Prevent MIME type sniffing
167    headers.insert("X-Content-Type-Options", "nosniff".parse().unwrap());
168
169    // Prevent clickjacking
170    headers.insert("X-Frame-Options", "DENY".parse().unwrap());
171
172    // XSS protection (legacy, but still useful)
173    headers.insert("X-XSS-Protection", "1; mode=block".parse().unwrap());
174
175    // Content Security Policy
176    headers.insert(
177        "Content-Security-Policy",
178        "default-src 'self'; frame-ancestors 'none'"
179            .parse()
180            .unwrap(),
181    );
182
183    // HSTS (only enable in production with HTTPS)
184    // headers.insert(
185    //     "Strict-Transport-Security",
186    //     "max-age=31536000; includeSubDomains".parse().unwrap(),
187    // );
188
189    // Referrer policy
190    headers.insert(
191        "Referrer-Policy",
192        "strict-origin-when-cross-origin".parse().unwrap(),
193    );
194
195    // Permissions policy
196    headers.insert(
197        "Permissions-Policy",
198        "geolocation=(), microphone=(), camera=()".parse().unwrap(),
199    );
200
201    response
202}
203
#[cfg(test)]
mod tests {
    use super::RequestId;

    /// Two generated ids must differ (sanity check on the id source).
    #[test]
    fn test_request_id() {
        let id1 = uuid::Uuid::new_v4().to_string();
        let id2 = uuid::Uuid::new_v4().to_string();
        assert_ne!(id1, id2);
    }

    /// `request_id_middleware` unwraps the header-value parse of the generated id. This test
    /// pins the invariant that a v4 UUID string always parses as a valid header value and
    /// round-trips back to a UUID.
    #[test]
    fn generated_request_id_is_a_valid_header_value() {
        let id = RequestId(uuid::Uuid::new_v4().to_string());
        assert!(id.0.parse::<axum::http::HeaderValue>().is_ok());
        assert!(uuid::Uuid::parse_str(&id.0).is_ok());
        assert_eq!(id.clone().0, id.0);
    }
}