1use axum::{
9 extract::Request,
10 http::{header::HeaderValue, StatusCode},
11 middleware::Next,
12 response::Response,
13};
14use std::time::Instant;
15use uuid::Uuid;
16
/// Per-request correlation identifier.
///
/// The [`request_id`] middleware inserts it into the request extensions so
/// downstream handlers can read it, and echoes it back in the
/// `X-Request-ID` response header. `PartialEq`, `Eq` and `Hash` are derived
/// so IDs can be compared (e.g. in tests) and used as map/set keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct RequestId(pub String);
20
21impl RequestId {
22 pub fn new() -> Self {
24 Self(Uuid::new_v4().to_string())
25 }
26
27 pub fn from_string(id: String) -> Self {
29 Self(id)
30 }
31
32 pub fn as_str(&self) -> &str {
34 &self.0
35 }
36}
37
38impl Default for RequestId {
39 fn default() -> Self {
40 Self::new()
41 }
42}
43
44impl std::fmt::Display for RequestId {
45 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46 write!(f, "{}", self.0)
47 }
48}
49
50pub const REQUEST_ID_HEADER: &str = "X-Request-ID";
52
53pub async fn request_id(mut req: Request, next: Next) -> Response {
61 let request_id = req
64 .headers()
65 .get(REQUEST_ID_HEADER)
66 .and_then(|v| v.to_str().ok())
67 .filter(|s| !s.is_empty() && s.len() <= 64)
68 .filter(|s| {
69 s.chars()
70 .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_' || c == '.')
71 })
72 .map(|s| RequestId::from_string(s.to_string()))
73 .unwrap_or_else(RequestId::new);
74
75 req.extensions_mut().insert(request_id.clone());
77
78 let _span = tracing::info_span!(
80 "request",
81 request_id = %request_id,
82 method = %req.method(),
83 path = %req.uri().path()
84 );
85
86 let mut response = next.run(req).await;
88
89 if let Ok(header_value) = HeaderValue::from_str(&request_id.0) {
91 response
92 .headers_mut()
93 .insert(REQUEST_ID_HEADER, header_value);
94 }
95
96 response
97}
98
99pub async fn security_headers(req: Request, next: Next) -> Response {
108 let mut response = next.run(req).await;
109 let headers = response.headers_mut();
110
111 headers.insert(
112 "X-Content-Type-Options",
113 HeaderValue::from_static("nosniff"),
114 );
115 headers.insert("X-Frame-Options", HeaderValue::from_static("DENY"));
116 headers.insert(
117 "Content-Security-Policy",
118 HeaderValue::from_static("default-src 'none'"),
119 );
120 headers.insert("Cache-Control", HeaderValue::from_static("no-store"));
121
122 if crate::auth::is_production_mode() {
124 headers.insert(
125 "Strict-Transport-Security",
126 HeaderValue::from_static("max-age=31536000; includeSubDomains"),
127 );
128 }
129
130 response
131}
132
133const SLOW_REQUEST_THRESHOLD_SECS: f64 = 30.0;
135
136pub async fn track_metrics(req: Request, next: Next) -> Result<Response, StatusCode> {
138 let start = Instant::now();
139 let method = req.method().to_string();
140 let path = req.uri().path().to_string();
141
142 let response = next.run(req).await;
144
145 let duration = start.elapsed().as_secs_f64();
147 let status_code = response.status();
148 let status = status_code.as_u16().to_string();
149
150 let normalized_path = normalize_path(&path);
152
153 if status_code == StatusCode::REQUEST_TIMEOUT {
155 tracing::error!(
156 method = %method,
157 path = %path,
158 normalized_path = %normalized_path,
159 duration_secs = %duration,
160 "Request timeout - endpoint exceeded configured timeout limit"
161 );
162 } else if duration > SLOW_REQUEST_THRESHOLD_SECS {
163 tracing::warn!(
165 method = %method,
166 path = %path,
167 normalized_path = %normalized_path,
168 duration_secs = %format!("{:.2}", duration),
169 "Slow request - approaching timeout threshold"
170 );
171 }
172
173 crate::metrics::HTTP_REQUEST_DURATION
174 .with_label_values(&[&method, &normalized_path, &status])
175 .observe(duration);
176
177 crate::metrics::HTTP_REQUESTS_TOTAL
178 .with_label_values(&[&method, &normalized_path, &status])
179 .inc();
180
181 Ok(response)
182}
183
/// Collapses ID-like segments of a URL path into the literal `{id}` so the
/// result can serve as a low-cardinality metrics label.
///
/// Empty segments (from leading, trailing or repeated `/`) are dropped and
/// the result always starts with exactly one `/`. For example
/// `/api/users/12345/stats` becomes `/api/users/{id}/stats`, and both `""`
/// and `"/"` become `"/"`.
fn normalize_path(path: &str) -> String {
    let segments: Vec<&str> = path
        .split('/')
        .filter(|segment| !segment.is_empty())
        .map(|segment| if is_id(segment) { "{id}" } else { segment })
        .collect();

    format!("/{}", segments.join("/"))
}

/// Route words that [`is_id`] must never classify as identifiers, even when
/// a later heuristic would otherwise match. Compared against the lowercased
/// segment.
const KNOWN_PATH_SEGMENTS: &[&str] = &[
    "v1",
    "v2",
    "v3",
    "api",
    "api2",
    "health",
    "metrics",
    "status",
    "docs",
    "swagger",
    "remember",
    "recall",
    "forget",
    "stats",
    "stream",
    "events",
    "settings",
    "config",
    "preferences",
    "notifications",
    "webhooks",
    "auth",
    "login",
    "logout",
    "register",
    "oauth",
    "token",
    "refresh",
    "list",
    "create",
    "update",
    "delete",
    "search",
    "query",
    "sync",
    "linear",
    "github",
    "import",
    "export",
    "memories",
    "context",
    "session",
    "working",
    "longterm",
    "consolidate",
    "maintenance",
    "repair",
    "rebuild",
    "graph",
    "edges",
    "entities",
    "reinforce",
    "coactivation",
    "introspection",
    "report",
    "consolidation",
    "learning",
    "tags",
    "date",
    "proactive",
    "verify",
    "index",
];

/// Heuristically decides whether a single path segment is an identifier
/// (and should therefore be replaced by `{id}` in metrics labels).
///
/// A segment is never an ID if its lowercase form is in
/// [`KNOWN_PATH_SEGMENTS`]. Otherwise it is an ID when any of these hold:
/// - it is a hyphenated UUID (`8-4-4-4-12` hex groups);
/// - it is all ASCII digits;
/// - it is a run of at least 16 hex digits containing at least one digit
///   (simple-form UUIDs, MD5/SHA-1/SHA-256 hex digests);
/// - it is more than 40 ASCII alphanumerics;
/// - it starts (case-insensitively) with a known prefix such as `user_`;
/// - it is 3–20 bytes long, contains both letters and digits, and its digit
///   count is at least `len / 2` (integer division).
fn is_id(segment: &str) -> bool {
    let lower = segment.to_lowercase();

    if KNOWN_PATH_SEGMENTS.contains(&lower.as_str()) {
        return false;
    }

    if is_hyphenated_uuid(segment) {
        return true;
    }

    // Numeric keys.
    if !segment.is_empty() && segment.chars().all(|c| c.is_ascii_digit()) {
        return true;
    }

    // Hex IDs/digests. Without this rule, 21–40 character hex IDs (e.g. a
    // 32-char simple UUID or a 40-char SHA-1) matched none of the other
    // rules and leaked into metric labels verbatim.
    if segment.len() >= 16
        && segment.chars().all(|c| c.is_ascii_hexdigit())
        && segment.chars().any(|c| c.is_ascii_digit())
    {
        return true;
    }

    // Long opaque tokens.
    if segment.len() > 40 && segment.chars().all(|c| c.is_ascii_alphanumeric()) {
        return true;
    }

    const ID_PREFIXES: [&str; 16] = [
        "user_", "user-", "mem_", "mem-", "id_", "id-", "uid_", "uid-", "drone_", "drone-",
        "robot_", "robot-", "session_", "session-", "mission_", "mission-",
    ];
    if ID_PREFIXES.iter().any(|prefix| lower.starts_with(prefix)) {
        return true;
    }

    // Short digit-heavy alphanumerics such as `user123`.
    if (3..=20).contains(&segment.len()) {
        let digit_count = segment.chars().filter(|c| c.is_ascii_digit()).count();
        let alpha_count = segment.chars().filter(|c| c.is_alphabetic()).count();
        if digit_count > 0 && alpha_count > 0 && digit_count >= segment.len() / 2 {
            return true;
        }
    }

    false
}

/// Returns true for the canonical hyphenated UUID form: five groups of hex
/// digits with lengths 8, 4, 4, 4 and 12 separated by `-` (36 bytes total).
fn is_hyphenated_uuid(segment: &str) -> bool {
    const GROUP_LENGTHS: [usize; 5] = [8, 4, 4, 4, 12];

    segment.len() == 36
        && segment.split('-').map(str::len).eq(GROUP_LENGTHS)
        && segment
            .split('-')
            .all(|group| group.chars().all(|c| c.is_ascii_hexdigit()))
}
332
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::OsString;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes the tests in this module that mutate the process-wide
    /// `SHODH_ENV` variable.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires [`ENV_LOCK`], recovering from poisoning so that one failed
    /// env-dependent test does not make every other one fail as well.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// RAII guard that sets (or clears) `SHODH_ENV` and restores the previous
    /// value when dropped — including when the test panics on an assertion.
    struct ShodhEnv {
        previous: Option<OsString>,
    }

    impl ShodhEnv {
        fn set(value: Option<&str>) -> Self {
            let previous = std::env::var_os("SHODH_ENV");
            match value {
                Some(v) => std::env::set_var("SHODH_ENV", v),
                None => std::env::remove_var("SHODH_ENV"),
            }
            Self { previous }
        }
    }

    impl Drop for ShodhEnv {
        fn drop(&mut self) {
            match self.previous.take() {
                Some(v) => std::env::set_var("SHODH_ENV", v),
                None => std::env::remove_var("SHODH_ENV"),
            }
        }
    }

    /// Sends `GET /test` through a router wrapped only in `security_headers`.
    async fn get_through_security_headers() -> axum::response::Response {
        use axum::body::Body;
        use axum::http::Request as HttpRequest;
        use axum::middleware::from_fn;
        use axum::routing::get;
        use axum::Router;
        use tower::ServiceExt;

        let app = Router::new()
            .route("/test", get(|| async { "ok" }))
            .layer(from_fn(security_headers));

        let req = HttpRequest::builder()
            .uri("/test")
            .body(Body::empty())
            .unwrap();
        app.oneshot(req).await.unwrap()
    }

    #[test]
    fn test_normalize_path() {
        assert_eq!(
            normalize_path("/api/users/user123/memories"),
            "/api/users/{id}/memories"
        );
        assert_eq!(
            normalize_path("/api/memories/550e8400-e29b-41d4-a716-446655440000"),
            "/api/memories/{id}"
        );
        assert_eq!(normalize_path("/health"), "/health");
        assert_eq!(
            normalize_path("/api/users/12345/stats"),
            "/api/users/{id}/stats"
        );
    }

    #[test]
    fn test_known_paths_not_normalized() {
        assert_eq!(
            normalize_path("/api/settings/notifications"),
            "/api/settings/notifications"
        );
        assert_eq!(normalize_path("/api/recall/tags"), "/api/recall/tags");
        assert_eq!(
            normalize_path("/api/consolidation/report"),
            "/api/consolidation/report"
        );
        assert_eq!(normalize_path("/api/v2/remember"), "/api/v2/remember");
    }

    #[test]
    fn test_uuid_detection() {
        assert!(is_id("550e8400-e29b-41d4-a716-446655440000"));
        assert!(!is_id("not-a-valid-uuid-at-all"));
    }

    #[tokio::test]
    // The lock must stay held across the await: `security_headers` consults
    // `crate::auth::is_production_mode()` while the request future runs, and
    // these tests assume that reads `SHODH_ENV`.
    #[allow(clippy::await_holding_lock)]
    async fn security_headers_present_in_dev_mode() {
        use axum::http::StatusCode;

        let _lock = lock_env();
        // Declared after `_lock`, so it is dropped (restoring the variable)
        // while the lock is still held.
        let _env = ShodhEnv::set(None);

        let resp = get_through_security_headers().await;

        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(
            resp.headers().get("X-Content-Type-Options").unwrap(),
            "nosniff"
        );
        assert_eq!(resp.headers().get("X-Frame-Options").unwrap(), "DENY");
        assert_eq!(
            resp.headers().get("Content-Security-Policy").unwrap(),
            "default-src 'none'"
        );
        assert_eq!(resp.headers().get("Cache-Control").unwrap(), "no-store");
        assert!(
            resp.headers().get("Strict-Transport-Security").is_none(),
            "HSTS should only be set in production"
        );
    }

    #[tokio::test]
    // See `security_headers_present_in_dev_mode` for why the lock spans the await.
    #[allow(clippy::await_holding_lock)]
    async fn security_headers_hsts_in_production() {
        let _lock = lock_env();
        let _env = ShodhEnv::set(Some("production"));

        let resp = get_through_security_headers().await;

        assert!(
            resp.headers().get("Strict-Transport-Security").is_some(),
            "HSTS should be set in production"
        );
        let hsts = resp
            .headers()
            .get("Strict-Transport-Security")
            .unwrap()
            .to_str()
            .unwrap();
        assert!(hsts.contains("max-age="));
    }
}