1use axum::{
4 extract::{Request, State},
5 http::header,
6 middleware::Next,
7 response::Response,
8};
9use std::time::Instant;
10
11use crate::auth::{Claims, JwtAuth};
12use crate::error::ApiError;
13use crate::state::AppState;
14pub async fn auth_middleware(
18 State(state): State<AppState>,
19 mut request: Request,
20 next: Next,
21) -> Result<Response, ApiError> {
22 let path = request.uri().path();
24 if path == "/health" || path.starts_with("/public/") {
25 return Ok(next.run(request).await);
26 }
27
28 let auth_header = request
30 .headers()
31 .get(header::AUTHORIZATION)
32 .and_then(|v| v.to_str().ok())
33 .ok_or_else(|| ApiError::Unauthorized("Missing Authorization header".to_string()))?;
34
35 let token = JwtAuth::extract_from_header(auth_header)?;
36 let claims = state.jwt_auth().decode(token)?;
37
38 request.extensions_mut().insert(claims);
40
41 Ok(next.run(request).await)
42}
43
44pub async fn rate_limit_middleware(
46 State(state): State<AppState>,
47 request: Request,
48 next: Next,
49) -> Result<Response, ApiError> {
50 let client_id = request
52 .extensions()
53 .get::<Claims>()
54 .map(|c| c.sub.clone())
55 .unwrap_or_else(|| "anonymous".to_string());
56
57 state
59 .rate_limiter()
60 .try_acquire(&client_id)
61 .await
62 .map_err(|_| ApiError::RateLimited)?;
63
64 Ok(next.run(request).await)
65}
66
67pub async fn tracing_middleware(
69 State(state): State<AppState>,
70 request: Request,
71 next: Next,
72) -> Response {
73 let start = Instant::now();
74 let method = request.method().clone();
75 let uri = request.uri().clone();
76 let path = uri.path().to_string();
77
78 let span = tracing::info_span!(
80 "http_request",
81 method = %method,
82 path = %path,
83 status = tracing::field::Empty,
84 latency_ms = tracing::field::Empty,
85 );
86
87 let response = {
88 let _enter = span.enter();
89 next.run(request).await
90 };
91
92 let latency = start.elapsed();
93 let status = response.status();
94
95 state.metrics().record_llm_call(0, !status.is_success());
97
98 tracing::info!(
100 method = %method,
101 path = %path,
102 status = %status.as_u16(),
103 latency_ms = %latency.as_millis(),
104 "Request completed"
105 );
106
107 response
108}
109
110pub async fn request_id_middleware(mut request: Request, next: Next) -> Response {
112 let request_id = uuid::Uuid::new_v4().to_string();
113
114 request
116 .extensions_mut()
117 .insert(RequestId(request_id.clone()));
118
119 let mut response = next.run(request).await;
120
121 response
123 .headers_mut()
124 .insert("X-Request-ID", request_id.parse().unwrap());
125
126 response
127}
128
/// Per-request identifier inserted into request extensions by `request_id_middleware`.
///
/// The wrapped string is the same value sent to the client in the
/// `X-Request-ID` response header. Equality and hashing are derived so ids
/// can be compared in tests and used as map keys.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct RequestId(pub String);
132
133pub fn cors_layer() -> tower_http::cors::CorsLayer {
135 tower_http::cors::CorsLayer::new()
136 .allow_origin(tower_http::cors::Any)
137 .allow_methods([
138 axum::http::Method::GET,
139 axum::http::Method::POST,
140 axum::http::Method::PUT,
141 axum::http::Method::DELETE,
142 axum::http::Method::OPTIONS,
143 ])
144 .allow_headers([header::AUTHORIZATION, header::CONTENT_TYPE, header::ACCEPT])
145 .max_age(std::time::Duration::from_secs(3600))
146}
147
148#[allow(deprecated)]
150pub fn timeout_layer(duration: std::time::Duration) -> tower_http::timeout::TimeoutLayer {
151 tower_http::timeout::TimeoutLayer::new(duration)
152}
153
154pub fn body_limit_layer(limit: usize) -> tower_http::limit::RequestBodyLimitLayer {
156 tower_http::limit::RequestBodyLimitLayer::new(limit)
157}
158
159pub async fn security_headers_middleware(request: Request, next: Next) -> Response {
162 let mut response = next.run(request).await;
163
164 let headers = response.headers_mut();
165
166 headers.insert("X-Content-Type-Options", "nosniff".parse().unwrap());
168
169 headers.insert("X-Frame-Options", "DENY".parse().unwrap());
171
172 headers.insert("X-XSS-Protection", "1; mode=block".parse().unwrap());
174
175 headers.insert(
177 "Content-Security-Policy",
178 "default-src 'self'; frame-ancestors 'none'"
179 .parse()
180 .unwrap(),
181 );
182
183 headers.insert(
191 "Referrer-Policy",
192 "strict-origin-when-cross-origin".parse().unwrap(),
193 );
194
195 headers.insert(
197 "Permissions-Policy",
198 "geolocation=(), microphone=(), camera=()".parse().unwrap(),
199 );
200
201 response
202}
203
#[cfg(test)]
mod tests {
    use super::RequestId;

    /// Two generated identifiers must differ, otherwise request correlation breaks.
    #[test]
    fn test_request_id() {
        let id1 = uuid::Uuid::new_v4().to_string();
        let id2 = uuid::Uuid::new_v4().to_string();
        assert_ne!(id1, id2);
    }

    /// `request_id_middleware` unwraps the conversion of a generated id into a
    /// header value. This pins the invariant that the conversion always succeeds.
    #[test]
    fn test_request_id_is_valid_header_value() {
        let id = uuid::Uuid::new_v4().to_string();
        assert!(id.parse::<axum::http::HeaderValue>().is_ok());
    }

    /// The extension wrapper must carry the id string unchanged through a clone.
    #[test]
    fn test_request_id_wrapper_holds_value() {
        let wrapped = RequestId("abc".to_string());
        assert_eq!(wrapped.clone().0, "abc");
    }
}