Skip to main content

shodh_memory/
middleware.rs

1//! P1.3: HTTP request tracking middleware for observability
2//!
3//! Provides:
4//! - Request ID generation and propagation
5//! - HTTP latency and count metrics
6//! - Path normalization to prevent cardinality explosion
7
8use axum::{
9    extract::Request,
10    http::{header::HeaderValue, StatusCode},
11    middleware::Next,
12    response::Response,
13};
14use std::time::Instant;
15use uuid::Uuid;
16
/// Request ID extension for correlation across logs and errors.
///
/// A thin newtype around the ID string. The inner field is public, so the raw
/// text can be read directly as `.0`.
///
/// Besides `Debug` and `Clone`, the type is comparable and hashable so that
/// IDs can be asserted on in tests or used as map/set keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct RequestId(pub String);
20
21impl RequestId {
22    /// Generate a new unique request ID
23    pub fn new() -> Self {
24        Self(Uuid::new_v4().to_string())
25    }
26
27    /// Create from existing ID string
28    pub fn from_string(id: String) -> Self {
29        Self(id)
30    }
31
32    /// Get the ID as a string slice
33    pub fn as_str(&self) -> &str {
34        &self.0
35    }
36}
37
38impl Default for RequestId {
39    fn default() -> Self {
40        Self::new()
41    }
42}
43
44impl std::fmt::Display for RequestId {
45    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46        write!(f, "{}", self.0)
47    }
48}
49
50/// Request ID header name (standard header used by many load balancers)
51pub const REQUEST_ID_HEADER: &str = "X-Request-ID";
52
53/// Middleware to add/propagate request IDs for distributed tracing
54///
55/// Behavior:
56/// - If `X-Request-ID` header is present in request, use it
57/// - Otherwise, generate a new UUID v4
58/// - Add the ID to response headers
59/// - Store in request extensions for downstream handlers
60pub async fn request_id(mut req: Request, next: Next) -> Response {
61    // Extract or generate request ID
62    // Sanitize: only allow [a-zA-Z0-9\-_.] to prevent log injection
63    let request_id = req
64        .headers()
65        .get(REQUEST_ID_HEADER)
66        .and_then(|v| v.to_str().ok())
67        .filter(|s| !s.is_empty() && s.len() <= 64)
68        .filter(|s| {
69            s.chars()
70                .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_' || c == '.')
71        })
72        .map(|s| RequestId::from_string(s.to_string()))
73        .unwrap_or_else(RequestId::new);
74
75    // Store in extensions for handlers to access
76    req.extensions_mut().insert(request_id.clone());
77
78    // Add to tracing span
79    let _span = tracing::info_span!(
80        "request",
81        request_id = %request_id,
82        method = %req.method(),
83        path = %req.uri().path()
84    );
85
86    // Process request
87    let mut response = next.run(req).await;
88
89    // Add request ID to response headers
90    if let Ok(header_value) = HeaderValue::from_str(&request_id.0) {
91        response
92            .headers_mut()
93            .insert(REQUEST_ID_HEADER, header_value);
94    }
95
96    response
97}
98
99/// Middleware to add security response headers
100///
101/// Adds:
102/// - X-Content-Type-Options: nosniff (prevent MIME-type sniffing)
103/// - X-Frame-Options: DENY (prevent clickjacking)
104/// - Content-Security-Policy: default-src 'none' (restrict resource loading)
105/// - Cache-Control: no-store (prevent caching of API responses)
106/// - Strict-Transport-Security (HSTS) in production mode only
107pub async fn security_headers(req: Request, next: Next) -> Response {
108    let mut response = next.run(req).await;
109    let headers = response.headers_mut();
110
111    headers.insert(
112        "X-Content-Type-Options",
113        HeaderValue::from_static("nosniff"),
114    );
115    headers.insert("X-Frame-Options", HeaderValue::from_static("DENY"));
116    headers.insert(
117        "Content-Security-Policy",
118        HeaderValue::from_static("default-src 'none'"),
119    );
120    headers.insert("Cache-Control", HeaderValue::from_static("no-store"));
121
122    // HSTS in production only (requires HTTPS to be meaningful)
123    if crate::auth::is_production_mode() {
124        headers.insert(
125            "Strict-Transport-Security",
126            HeaderValue::from_static("max-age=31536000; includeSubDomains"),
127        );
128    }
129
130    response
131}
132
133/// Slow request warning threshold (seconds)
134const SLOW_REQUEST_THRESHOLD_SECS: f64 = 30.0;
135
136/// P1.3: Middleware to track HTTP request latency and counts
137pub async fn track_metrics(req: Request, next: Next) -> Result<Response, StatusCode> {
138    let start = Instant::now();
139    let method = req.method().to_string();
140    let path = req.uri().path().to_string();
141
142    // Process request
143    let response = next.run(req).await;
144
145    // Record metrics
146    let duration = start.elapsed().as_secs_f64();
147    let status_code = response.status();
148    let status = status_code.as_u16().to_string();
149
150    // Normalize path to avoid high cardinality (group dynamic IDs)
151    let normalized_path = normalize_path(&path);
152
153    // Log timeouts (408) for observability - helps identify which endpoints need attention
154    if status_code == StatusCode::REQUEST_TIMEOUT {
155        tracing::error!(
156            method = %method,
157            path = %path,
158            normalized_path = %normalized_path,
159            duration_secs = %duration,
160            "Request timeout - endpoint exceeded configured timeout limit"
161        );
162    } else if duration > SLOW_REQUEST_THRESHOLD_SECS {
163        // Log slow requests that haven't timed out yet
164        tracing::warn!(
165            method = %method,
166            path = %path,
167            normalized_path = %normalized_path,
168            duration_secs = %format!("{:.2}", duration),
169            "Slow request - approaching timeout threshold"
170        );
171    }
172
173    crate::metrics::HTTP_REQUEST_DURATION
174        .with_label_values(&[&method, &normalized_path, &status])
175        .observe(duration);
176
177    crate::metrics::HTTP_REQUESTS_TOTAL
178        .with_label_values(&[&method, &normalized_path, &status])
179        .inc();
180
181    Ok(response)
182}
183
184/// Normalize path to prevent metric cardinality explosion
185/// /api/users/user123/memories -> /api/users/{id}/memories
186fn normalize_path(path: &str) -> String {
187    let parts: Vec<&str> = path.split('/').collect();
188    let mut normalized = Vec::new();
189
190    for part in parts {
191        if part.is_empty() {
192            continue;
193        }
194
195        // Replace UUIDs and IDs with placeholders
196        if is_id(part) {
197            normalized.push("{id}");
198        } else {
199            normalized.push(part);
200        }
201    }
202
203    format!("/{}", normalized.join("/"))
204}
205
/// Route words that must never be collapsed into `{id}` (SHO-71).
///
/// Matching is done against the lowercased segment, so entries are lowercase.
const KNOWN_PATH_SEGMENTS: &[&str] = &[
    "v1",
    "v2",
    "v3",
    "api",
    "api2",
    "health",
    "metrics",
    "status",
    "docs",
    "swagger",
    "remember",
    "recall",
    "forget",
    "stats",
    "stream",
    "events",
    "settings",
    "config",
    "preferences",
    "notifications",
    "webhooks",
    "auth",
    "login",
    "logout",
    "register",
    "oauth",
    "token",
    "refresh",
    "list",
    "create",
    "update",
    "delete",
    "search",
    "query",
    "sync",
    "linear",
    "github",
    "import",
    "export",
    "memories",
    "context",
    "session",
    "working",
    "longterm",
    "consolidate",
    "maintenance",
    "repair",
    "rebuild",
    "graph",
    "edges",
    "entities",
    "reinforce",
    "coactivation",
    "introspection",
    "report",
    "consolidation",
    "learning",
    "tags",
    "date",
    "proactive",
    "verify",
    "index",
];

/// Decide whether a single path segment looks like a dynamic identifier (SHO-71).
///
/// A segment found in `KNOWN_PATH_SEGMENTS` (case-insensitively) is never an ID.
/// Otherwise it is an ID when any of these hold:
/// - it is a dashed UUID: 36 bytes, hex groups of 8-4-4-4-12;
/// - it is non-empty and consists only of ASCII digits;
/// - it is longer than 40 bytes and purely ASCII alphanumeric (hash-like);
/// - its lowercased form starts with a known ID prefix such as `user_` or `mem-`;
/// - it is 3 to 20 bytes long, contains at least one digit and one letter, and
///   the digit count is at least half the byte length, rounded down
///   (so `abc123`, `x99` and `user123` all qualify).
fn is_id(segment: &str) -> bool {
    /// Byte lengths of the dash-separated groups of a canonical UUID.
    const UUID_GROUP_LENS: [usize; 5] = [8, 4, 4, 4, 12];
    /// Prefixes that mark a segment as an ID: user_123, mem_abc, id-456.
    const ID_PREFIXES: [&str; 16] = [
        "user_", "user-", "mem_", "mem-", "id_", "id-", "uid_", "uid-", "drone_", "drone-",
        "robot_", "robot-", "session_", "session-", "mission_", "mission-",
    ];

    let lowered = segment.to_lowercase();

    // The allow-list wins over every heuristic below.
    if KNOWN_PATH_SEGMENTS.iter().any(|&known| known == lowered) {
        return false;
    }

    // Matching group lengths already imply exactly four dashes.
    let looks_like_uuid = segment.len() == 36
        && segment.split('-').map(str::len).eq(UUID_GROUP_LENS)
        && segment
            .split('-')
            .all(|group| group.chars().all(|c| c.is_ascii_hexdigit()));

    let purely_numeric = !segment.is_empty() && segment.bytes().all(|b| b.is_ascii_digit());

    let hash_like = segment.len() > 40 && segment.bytes().all(|b| b.is_ascii_alphanumeric());

    let has_id_prefix = ID_PREFIXES.iter().any(|&prefix| lowered.starts_with(prefix));

    let short_mixed = (3..=20).contains(&segment.len()) && {
        let digits = segment.chars().filter(char::is_ascii_digit).count();
        let letters = segment.chars().filter(|c| c.is_alphabetic()).count();
        // Integer division: a 7-byte segment needs only 3 digits.
        digits > 0 && letters > 0 && digits >= segment.len() / 2
    };

    looks_like_uuid || purely_numeric || hash_like || has_id_prefix || short_mixed
}
332
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::OsString;
    use std::sync::{Mutex, MutexGuard, PoisonError};

    /// Process-global lock for tests that manipulate environment variables.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Environment variable toggled by the `security_headers` tests.
    const ENV_VAR: &str = "SHODH_ENV";

    /// Take the env lock, recovering it if a previous test panicked while
    /// holding it. The lock guards no data, so a poisoned state carries no
    /// meaning; without this, one failing test would fail the other as well.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(PoisonError::into_inner)
    }

    /// Temporarily overrides `SHODH_ENV` and restores the prior state on drop,
    /// including when the test unwinds from a failed assertion.
    struct EnvOverride {
        previous: Option<OsString>,
    }

    impl EnvOverride {
        /// Set the variable to `value`, or remove it when `value` is `None`.
        fn apply(value: Option<&str>) -> Self {
            let previous = std::env::var_os(ENV_VAR);
            match value {
                Some(v) => std::env::set_var(ENV_VAR, v),
                None => std::env::remove_var(ENV_VAR),
            }
            Self { previous }
        }
    }

    impl Drop for EnvOverride {
        fn drop(&mut self) {
            match self.previous.take() {
                Some(v) => std::env::set_var(ENV_VAR, v),
                None => std::env::remove_var(ENV_VAR),
            }
        }
    }

    /// Send `GET /test` through a router wrapped in `security_headers`.
    async fn get_through_security_headers() -> axum::response::Response {
        use axum::body::Body;
        use axum::http::Request as HttpRequest;
        use axum::middleware::from_fn;
        use axum::routing::get;
        use axum::Router;
        use tower::ServiceExt;

        let app = Router::new()
            .route("/test", get(|| async { "ok" }))
            .layer(from_fn(security_headers));

        let req = HttpRequest::builder()
            .uri("/test")
            .body(Body::empty())
            .unwrap();

        app.oneshot(req).await.unwrap()
    }

    #[test]
    fn test_normalize_path() {
        // User IDs should be normalized
        assert_eq!(
            normalize_path("/api/users/user123/memories"),
            "/api/users/{id}/memories"
        );
        // UUIDs should be normalized
        assert_eq!(
            normalize_path("/api/memories/550e8400-e29b-41d4-a716-446655440000"),
            "/api/memories/{id}"
        );
        // Health check should NOT be normalized
        assert_eq!(normalize_path("/health"), "/health");
        // Numeric IDs should be normalized
        assert_eq!(
            normalize_path("/api/users/12345/stats"),
            "/api/users/{id}/stats"
        );
    }

    #[test]
    fn test_known_paths_not_normalized() {
        // SHO-71: Known path segments should NOT be treated as IDs
        assert_eq!(
            normalize_path("/api/settings/notifications"),
            "/api/settings/notifications"
        );
        assert_eq!(normalize_path("/api/recall/tags"), "/api/recall/tags");
        assert_eq!(
            normalize_path("/api/consolidation/report"),
            "/api/consolidation/report"
        );
        assert_eq!(normalize_path("/api/v2/remember"), "/api/v2/remember");
    }

    #[test]
    fn test_uuid_detection() {
        // Valid UUID should be detected
        assert!(is_id("550e8400-e29b-41d4-a716-446655440000"));
        // Invalid UUID-like strings should not match
        assert!(!is_id("not-a-valid-uuid-at-all"));
    }

    // ── security_headers ──

    #[tokio::test]
    async fn security_headers_present_in_dev_mode() {
        // Declared lock-first so the env is restored before the lock is released.
        let _guard = lock_env();
        let _env = EnvOverride::apply(None);

        let resp = get_through_security_headers().await;

        assert_eq!(resp.status(), axum::http::StatusCode::OK);
        assert_eq!(
            resp.headers().get("X-Content-Type-Options").unwrap(),
            "nosniff"
        );
        assert_eq!(resp.headers().get("X-Frame-Options").unwrap(), "DENY");
        assert_eq!(
            resp.headers().get("Content-Security-Policy").unwrap(),
            "default-src 'none'"
        );
        assert_eq!(resp.headers().get("Cache-Control").unwrap(), "no-store");
        // HSTS should NOT be present in dev mode
        assert!(
            resp.headers().get("Strict-Transport-Security").is_none(),
            "HSTS should only be set in production"
        );
    }

    #[tokio::test]
    async fn security_headers_hsts_in_production() {
        // Declared lock-first so the env is restored before the lock is released.
        let _guard = lock_env();
        let _env = EnvOverride::apply(Some("production"));

        let resp = get_through_security_headers().await;

        assert!(
            resp.headers().get("Strict-Transport-Security").is_some(),
            "HSTS should be set in production"
        );
        let hsts = resp
            .headers()
            .get("Strict-Transport-Security")
            .unwrap()
            .to_str()
            .unwrap();
        assert!(hsts.contains("max-age="));
    }
}
463}