turbomcp_server/
middleware.rs

1//! Middleware system for request/response processing
2
3use async_trait::async_trait;
4use std::collections::HashMap;
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7use tokio::sync::RwLock;
8use turbomcp_core::RequestContext;
9use turbomcp_protocol::jsonrpc::{JsonRpcRequest, JsonRpcResponse};
10
11use crate::{ServerError, ServerResult};
12
13/// Middleware trait for processing requests and responses
14#[async_trait]
15pub trait Middleware: Send + Sync {
16    /// Process request before routing
17    async fn process_request(
18        &self,
19        request: &mut JsonRpcRequest,
20        ctx: &mut RequestContext,
21    ) -> ServerResult<()>;
22
23    /// Process response after routing
24    async fn process_response(
25        &self,
26        response: &mut JsonRpcResponse,
27        ctx: &RequestContext,
28    ) -> ServerResult<()>;
29
30    /// Get middleware name
31    fn name(&self) -> &str;
32
33    /// Get middleware priority (lower numbers = higher priority)
34    fn priority(&self) -> u32 {
35        100
36    }
37
38    /// Check if middleware is enabled
39    fn enabled(&self) -> bool {
40        true
41    }
42}
43
44/// Middleware stack for composing multiple middleware
45pub struct MiddlewareStack {
46    /// Ordered list of middleware
47    middleware: Vec<Arc<dyn Middleware>>,
48    /// Stack configuration
49    config: StackConfig,
50}
51
52impl std::fmt::Debug for MiddlewareStack {
53    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54        f.debug_struct("MiddlewareStack")
55            .field("middleware_count", &self.middleware.len())
56            .field("config", &self.config)
57            .finish()
58    }
59}
60
/// Settings that govern how a [`MiddlewareStack`] runs its middleware.
#[derive(Debug, Clone)]
pub struct StackConfig {
    /// Metrics toggle.
    /// NOTE(review): `MiddlewareStack` in this module never reads this flag; confirm where it is consumed.
    pub enable_metrics: bool,
    /// Emit a `debug` trace event with the duration of each middleware call.
    pub enable_tracing: bool,
    /// Per-call middleware timeout in milliseconds; `0` disables the timeout.
    pub timeout_ms: u64,
    /// When `true`, a middleware error or timeout is logged and the stack continues
    /// with the next middleware instead of failing the whole request.
    pub enable_recovery: bool,
}
73
74impl Default for StackConfig {
75    fn default() -> Self {
76        Self {
77            enable_metrics: true,
78            enable_tracing: true,
79            timeout_ms: 5_000,
80            enable_recovery: true,
81        }
82    }
83}
84
85impl MiddlewareStack {
86    /// Create a new middleware stack
87    #[must_use]
88    pub fn new() -> Self {
89        Self {
90            middleware: Vec::new(),
91            config: StackConfig::default(),
92        }
93    }
94
95    /// Create a stack with configuration
96    #[must_use]
97    pub fn with_config(config: StackConfig) -> Self {
98        Self {
99            middleware: Vec::new(),
100            config,
101        }
102    }
103
104    /// Add middleware to the stack
105    pub fn add<M>(&mut self, middleware: M)
106    where
107        M: Middleware + 'static,
108    {
109        self.middleware.push(Arc::new(middleware));
110        self.sort_by_priority();
111    }
112
113    /// Remove middleware by name
114    pub fn remove(&mut self, name: &str) {
115        self.middleware.retain(|m| m.name() != name);
116    }
117
118    /// Process request through all middleware
119    pub async fn process_request(
120        &self,
121        mut request: JsonRpcRequest,
122        mut ctx: RequestContext,
123    ) -> ServerResult<(JsonRpcRequest, RequestContext)> {
124        // Record a start timestamp for end-to-end latency
125        let global_start = Instant::now();
126        for middleware in &self.middleware {
127            if !middleware.enabled() {
128                continue;
129            }
130
131            let start = Instant::now();
132
133            // Apply timeout if configured
134            let result = if self.config.timeout_ms > 0 {
135                tokio::time::timeout(
136                    Duration::from_millis(self.config.timeout_ms),
137                    middleware.process_request(&mut request, &mut ctx),
138                )
139                .await
140            } else {
141                Ok(middleware.process_request(&mut request, &mut ctx).await)
142            };
143
144            let duration = start.elapsed();
145
146            if self.config.enable_tracing {
147                tracing::debug!(
148                    middleware = middleware.name(),
149                    duration_ms = duration.as_millis(),
150                    "Processed request through middleware"
151                );
152            }
153
154            match result {
155                Ok(Ok(())) => continue,
156                Ok(Err(e)) => {
157                    if self.config.enable_recovery {
158                        tracing::warn!(
159                            middleware = middleware.name(),
160                            error = %e,
161                            "Middleware error, continuing with recovery"
162                        );
163                        continue;
164                    }
165                    return Err(ServerError::middleware(middleware.name(), e.to_string()));
166                }
167                Err(_) => {
168                    let _error = format!(
169                        "Middleware '{}' timed out after {}ms",
170                        middleware.name(),
171                        self.config.timeout_ms
172                    );
173                    if self.config.enable_recovery {
174                        tracing::warn!(
175                            middleware = middleware.name(),
176                            "Middleware timeout, continuing"
177                        );
178                        continue;
179                    }
180                    return Err(ServerError::timeout("middleware", self.config.timeout_ms));
181                }
182            }
183        }
184
185        // Correlation/request identifiers
186        let correlation_id = ctx
187            .metadata
188            .get("correlation_id")
189            .and_then(|v| v.as_str())
190            .map_or_else(
191                || uuid::Uuid::new_v4().to_string(),
192                std::string::ToString::to_string,
193            );
194        ctx = ctx.with_metadata("correlation_id", correlation_id);
195
196        // Store precise start time and monotonic start in metadata
197        let start_ns = start_ts();
198        let request_id = ctx.request_id.clone();
199        ctx = ctx.with_metadata("request_start_ns", start_ns);
200        ctx = ctx.with_metadata("request_id", request_id);
201        // Also include wall-clock duration so far (best-effort)
202        ctx = ctx.with_metadata(
203            "middleware_time_ms",
204            global_start.elapsed().as_millis() as u64,
205        );
206        Ok((request, ctx))
207    }
208
209    /// Process response through all middleware (in reverse order)
210    pub async fn process_response(
211        &self,
212        mut response: JsonRpcResponse,
213        ctx: &RequestContext,
214    ) -> ServerResult<JsonRpcResponse> {
215        for middleware in self.middleware.iter().rev() {
216            if !middleware.enabled() {
217                continue;
218            }
219
220            let start = Instant::now();
221
222            // Apply timeout if configured
223            let result = if self.config.timeout_ms > 0 {
224                tokio::time::timeout(
225                    Duration::from_millis(self.config.timeout_ms),
226                    middleware.process_response(&mut response, ctx),
227                )
228                .await
229            } else {
230                Ok(middleware.process_response(&mut response, ctx).await)
231            };
232
233            let duration = start.elapsed();
234
235            if self.config.enable_tracing {
236                tracing::debug!(
237                    middleware = middleware.name(),
238                    duration_ms = duration.as_millis(),
239                    "Processed response through middleware"
240                );
241            }
242
243            match result {
244                Ok(Ok(())) => continue,
245                Ok(Err(e)) => {
246                    if self.config.enable_recovery {
247                        tracing::warn!(
248                            middleware = middleware.name(),
249                            error = %e,
250                            "Middleware error in response processing, continuing"
251                        );
252                        continue;
253                    }
254                    return Err(ServerError::middleware(middleware.name(), e.to_string()));
255                }
256                Err(_) => {
257                    if self.config.enable_recovery {
258                        tracing::warn!(
259                            middleware = middleware.name(),
260                            "Middleware timeout in response processing, continuing"
261                        );
262                        continue;
263                    }
264                    return Err(ServerError::timeout("middleware", self.config.timeout_ms));
265                }
266            }
267        }
268
269        // Compute end-to-end latency if start_ns present
270        if let Some(ns) = ctx
271            .metadata
272            .get("request_start_ns")
273            .and_then(serde_json::Value::as_u64)
274        {
275            let end_ns = start_ts();
276            let elapsed_ns = end_ns.saturating_sub(ns);
277            let latency_ms = (elapsed_ns as f64) / 1_000_000.0;
278            tracing::debug!(
279                correlation_id = ctx.metadata.get("correlation_id").and_then(|v| v.as_str()),
280                request_id = %ctx.request_id,
281                latency_ms,
282                "Request completed with latency"
283            );
284        }
285        Ok(response)
286    }
287
288    /// Get middleware count
289    #[must_use]
290    pub fn len(&self) -> usize {
291        self.middleware.len()
292    }
293
294    /// Check if stack is empty
295    #[must_use]
296    pub fn is_empty(&self) -> bool {
297        self.middleware.is_empty()
298    }
299
300    /// List all middleware names
301    #[must_use]
302    pub fn list_middleware(&self) -> Vec<&str> {
303        self.middleware.iter().map(|m| m.name()).collect()
304    }
305
306    fn sort_by_priority(&mut self) {
307        self.middleware.sort_by_key(|m| m.priority());
308    }
309}
310
311impl Default for MiddlewareStack {
312    fn default() -> Self {
313        Self::new()
314    }
315}
316
317/// Authentication middleware
318pub struct AuthenticationMiddleware {
319    /// Authentication provider
320    provider: Arc<dyn AuthProvider>,
321    /// Middleware configuration
322    config: AuthConfig,
323}
324
325impl std::fmt::Debug for AuthenticationMiddleware {
326    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
327        f.debug_struct("AuthenticationMiddleware")
328            .field("config", &self.config)
329            .finish()
330    }
331}
332
333/// Authentication configuration
334#[derive(Debug, Clone)]
335pub struct AuthConfig {
336    /// Skip authentication for certain methods
337    pub skip_methods: Vec<String>,
338    /// Authentication scheme
339    pub scheme: AuthScheme,
340    /// Token expiry duration
341    pub token_expiry: Duration,
342}
343
/// Supported credential schemes.
#[derive(Debug, Clone)]
pub enum AuthScheme {
    /// Bearer-token credentials.
    Bearer,
    /// API-key credentials.
    ApiKey,
    /// HTTP Basic credentials.
    Basic,
    /// An application-defined scheme, identified by the contained string.
    Custom(String),
}
356
357/// Authentication provider trait
358#[async_trait]
359pub trait AuthProvider: Send + Sync {
360    /// Authenticate a request
361    async fn authenticate(&self, request: &JsonRpcRequest) -> ServerResult<AuthContext>;
362
363    /// Validate token
364    async fn validate_token(&self, token: &str) -> ServerResult<AuthContext>;
365}
366
367/// Authentication context
368#[derive(Debug, Clone)]
369pub struct AuthContext {
370    /// User ID
371    pub user_id: String,
372    /// User roles
373    pub roles: Vec<String>,
374    /// Token expiry
375    pub expires_at: Option<chrono::DateTime<chrono::Utc>>,
376    /// Additional claims
377    pub claims: HashMap<String, serde_json::Value>,
378}
379
380impl AuthenticationMiddleware {
381    /// Create new authentication middleware
382    pub fn new<P>(provider: P) -> Self
383    where
384        P: AuthProvider + 'static,
385    {
386        Self {
387            provider: Arc::new(provider),
388            config: AuthConfig {
389                skip_methods: vec!["initialize".to_string()],
390                scheme: AuthScheme::Bearer,
391                token_expiry: Duration::from_secs(3600),
392            },
393        }
394    }
395
396    /// Create with configuration
397    pub fn with_config<P>(provider: P, config: AuthConfig) -> Self
398    where
399        P: AuthProvider + 'static,
400    {
401        Self {
402            provider: Arc::new(provider),
403            config,
404        }
405    }
406}
407
408#[async_trait]
409impl Middleware for AuthenticationMiddleware {
410    async fn process_request(
411        &self,
412        request: &mut JsonRpcRequest,
413        _ctx: &mut RequestContext,
414    ) -> ServerResult<()> {
415        // Skip authentication for certain methods
416        if self.config.skip_methods.contains(&request.method) {
417            return Ok(());
418        }
419
420        match self.provider.authenticate(request).await {
421            Ok(auth_ctx) => {
422                // Propagate auth into RequestContext
423                _ctx.user_id = Some(auth_ctx.user_id.clone());
424                let meta = std::sync::Arc::make_mut(&mut _ctx.metadata);
425                meta.insert("authenticated".to_string(), serde_json::json!(true));
426                meta.insert(
427                    "auth".to_string(),
428                    serde_json::json!({
429                        "user_id": auth_ctx.user_id,
430                        "roles": auth_ctx.roles,
431                        "expires_at": auth_ctx.expires_at.map(|t| t.to_rfc3339()),
432                        "claims": auth_ctx.claims,
433                    }),
434                );
435                Ok(())
436            }
437            Err(e) => Err(ServerError::authentication(format!(
438                "Authentication failed: {e}"
439            ))),
440        }
441    }
442
443    async fn process_response(
444        &self,
445        _response: &mut JsonRpcResponse,
446        _ctx: &RequestContext,
447    ) -> ServerResult<()> {
448        Ok(())
449    }
450
451    fn name(&self) -> &'static str {
452        "authentication"
453    }
454
455    fn priority(&self) -> u32 {
456        10 // High priority
457    }
458}
459
460/// Rate limiting middleware
461#[derive(Debug)]
462pub struct RateLimitMiddleware {
463    /// Rate limiter
464    limiter: Arc<RateLimiter>,
465    /// Rate limit configuration
466    config: RateLimitConfig,
467}
468
469/// Rate limiting configuration
470#[derive(Debug, Clone)]
471pub struct RateLimitConfig {
472    /// Requests per second limit
473    pub requests_per_second: u32,
474    /// Burst capacity
475    pub burst_capacity: u32,
476    /// Rate limit key extractor
477    pub key_extractor: KeyExtractor,
478}
479
/// Strategies for choosing which rate-limit bucket a request draws from.
#[derive(Debug, Clone)]
pub enum KeyExtractor {
    /// The `client_ip` context metadata entry (falls back to `"unknown"`).
    ClientIp,
    /// `auth.user_id` from the context metadata (falls back to `"anonymous"`).
    UserId,
    /// The `api_key` context metadata entry (falls back to `"unknown"`).
    ApiKey,
    /// The named context metadata entry (falls back to `"unknown"`).
    Custom(String),
    /// A single bucket shared by every request.
    Global,
}
494
495/// Rate limiter implementation
496#[derive(Debug)]
497pub struct RateLimiter {
498    /// Rate limit entries
499    entries: Arc<RwLock<HashMap<String, RateLimitEntry>>>,
500    /// Cleanup task handle (None in tests)
501    _cleanup_handle: Option<tokio::task::JoinHandle<()>>,
502}
503
/// The token-bucket state for a single rate-limit key.
#[derive(Debug, Clone)]
struct RateLimitEntry {
    /// Tokens currently available to spend.
    tokens: u32,
    /// Reference point for the next refill calculation.
    last_refill: Instant,
    /// After this instant, the entry may be evicted by the cleanup task.
    expires_at: Instant,
}
514
515impl RateLimiter {
516    /// Create new rate limiter with background cleanup task
517    #[must_use]
518    pub fn new(_requests_per_second: u32, _burst_capacity: u32) -> Self {
519        let entries = Arc::new(RwLock::new(HashMap::<String, RateLimitEntry>::new()));
520
521        // Cleanup task
522        let cleanup_entries = Arc::clone(&entries);
523        let cleanup_handle = tokio::spawn(async move {
524            let mut interval = tokio::time::interval(Duration::from_secs(60));
525            loop {
526                interval.tick().await;
527                let now = Instant::now();
528                let mut entries = cleanup_entries.write().await;
529                entries.retain(|_, entry| entry.expires_at > now);
530            }
531        });
532
533        Self {
534            entries,
535            _cleanup_handle: Some(cleanup_handle),
536        }
537    }
538
539    /// Create new rate limiter for testing (no background tasks)
540    #[must_use]
541    #[cfg(test)]
542    pub fn new_for_testing(_requests_per_second: u32, _burst_capacity: u32) -> Self {
543        let entries = Arc::new(RwLock::new(HashMap::<String, RateLimitEntry>::new()));
544
545        Self {
546            entries,
547            _cleanup_handle: None, // No cleanup task in tests
548        }
549    }
550
551    /// Check if request is allowed
552    pub async fn check_rate_limit(
553        &self,
554        key: &str,
555        requests_per_second: u32,
556        burst_capacity: u32,
557    ) -> bool {
558        let mut entries = self.entries.write().await;
559        let now = Instant::now();
560
561        let entry = entries.entry(key.to_string()).or_insert(RateLimitEntry {
562            tokens: burst_capacity,
563            last_refill: now,
564            expires_at: now + Duration::from_secs(300), // 5 minutes
565        });
566
567        // Refill tokens based on time elapsed
568        let time_elapsed = now.duration_since(entry.last_refill);
569        let tokens_to_add = (time_elapsed.as_secs_f64() * f64::from(requests_per_second)) as u32;
570
571        if tokens_to_add > 0 {
572            entry.tokens = (entry.tokens + tokens_to_add).min(burst_capacity);
573            entry.last_refill = now;
574        }
575
576        if entry.tokens > 0 {
577            entry.tokens -= 1;
578            entry.expires_at = now + Duration::from_secs(300);
579            true
580        } else {
581            false
582        }
583    }
584}
585
586impl RateLimitMiddleware {
587    /// Create new rate limit middleware
588    #[must_use]
589    pub fn new(config: RateLimitConfig) -> Self {
590        let limiter = Arc::new(RateLimiter::new(
591            config.requests_per_second,
592            config.burst_capacity,
593        ));
594
595        Self { limiter, config }
596    }
597
598    /// Create new rate limit middleware for testing (no background tasks)
599    #[must_use]
600    #[cfg(test)]
601    pub fn new_for_testing(config: RateLimitConfig) -> Self {
602        let limiter = Arc::new(RateLimiter::new_for_testing(
603            config.requests_per_second,
604            config.burst_capacity,
605        ));
606
607        Self { limiter, config }
608    }
609}
610
611#[async_trait]
612impl Middleware for RateLimitMiddleware {
613    async fn process_request(
614        &self,
615        _request: &mut JsonRpcRequest,
616        ctx: &mut RequestContext,
617    ) -> ServerResult<()> {
618        let key = match &self.config.key_extractor {
619            KeyExtractor::ClientIp => ctx
620                .metadata
621                .get("client_ip")
622                .and_then(|v| v.as_str())
623                .unwrap_or("unknown")
624                .to_string(),
625            KeyExtractor::UserId => ctx
626                .metadata
627                .get("auth")
628                .and_then(|v| v.get("user_id"))
629                .and_then(|v| v.as_str())
630                .unwrap_or("anonymous")
631                .to_string(),
632            KeyExtractor::ApiKey => ctx
633                .metadata
634                .get("api_key")
635                .and_then(|v| v.as_str())
636                .unwrap_or("unknown")
637                .to_string(),
638            KeyExtractor::Custom(field) => ctx
639                .metadata
640                .get(field)
641                .and_then(|v| v.as_str())
642                .unwrap_or("unknown")
643                .to_string(),
644            KeyExtractor::Global => "global".to_string(),
645        };
646
647        let allowed = self
648            .limiter
649            .check_rate_limit(
650                &key,
651                self.config.requests_per_second,
652                self.config.burst_capacity,
653            )
654            .await;
655
656        if allowed {
657            Ok(())
658        } else {
659            Err(ServerError::rate_limit_with_retry(
660                format!("Rate limit exceeded for key: {key}"),
661                60, // Retry after 60 seconds
662            ))
663        }
664    }
665
666    async fn process_response(
667        &self,
668        _response: &mut JsonRpcResponse,
669        _ctx: &RequestContext,
670    ) -> ServerResult<()> {
671        Ok(())
672    }
673
674    fn name(&self) -> &'static str {
675        "rate_limit"
676    }
677
678    fn priority(&self) -> u32 {
679        20 // High priority, but after auth
680    }
681}
682
683/// Logging middleware for request/response logging
684#[derive(Debug)]
685pub struct LoggingMiddleware {
686    /// Logging configuration
687    config: LoggingConfig,
688}
689
/// Settings for [`LoggingMiddleware`].
#[derive(Debug, Clone)]
pub struct LoggingConfig {
    /// Include the serialized request in the request log line.
    pub log_request_body: bool,
    /// Log the serialized response at `debug` level.
    pub log_response_body: bool,
    /// Log completion time, measured from the context's `start_time`.
    pub log_timing: bool,
    /// Serialized bodies longer than this many bytes are not logged; only their size is.
    pub max_body_size: usize,
}
702
703impl Default for LoggingConfig {
704    fn default() -> Self {
705        Self {
706            log_request_body: false,
707            log_response_body: false,
708            log_timing: true,
709            max_body_size: 1024,
710        }
711    }
712}
713
714impl LoggingMiddleware {
715    /// Create new logging middleware
716    #[must_use]
717    pub fn new() -> Self {
718        Self {
719            config: LoggingConfig::default(),
720        }
721    }
722
723    /// Create with configuration
724    #[must_use]
725    pub const fn with_config(config: LoggingConfig) -> Self {
726        Self { config }
727    }
728}
729
730impl Default for LoggingMiddleware {
731    fn default() -> Self {
732        Self::new()
733    }
734}
735
736#[async_trait]
737impl Middleware for LoggingMiddleware {
738    async fn process_request(
739        &self,
740        request: &mut JsonRpcRequest,
741        ctx: &mut RequestContext,
742    ) -> ServerResult<()> {
743        // RequestContext already tracks start_time internally
744        let _start_time = ctx.start_time;
745
746        if self.config.log_request_body {
747            if let Ok(body) = serde_json::to_string(request) {
748                if body.len() <= self.config.max_body_size {
749                    tracing::info!(method = %request.method, body = %body, "Request received");
750                } else {
751                    tracing::info!(method = %request.method, body_size = body.len(), "Request received (body truncated)");
752                }
753            }
754        } else {
755            tracing::info!(method = %request.method, id = ?request.id, "Request received");
756        }
757
758        Ok(())
759    }
760
761    async fn process_response(
762        &self,
763        response: &mut JsonRpcResponse,
764        ctx: &RequestContext,
765    ) -> ServerResult<()> {
766        if self.config.log_timing {
767            // Calculate duration from start time
768            let duration = ctx.start_time.elapsed();
769            tracing::info!(
770                id = ?response.id,
771                has_error = response.error.is_some(),
772                duration_ms = duration.as_millis(),
773                "Request completed"
774            );
775        }
776
777        if self.config.log_response_body
778            && let Ok(body) = serde_json::to_string(response)
779        {
780            if body.len() <= self.config.max_body_size {
781                tracing::debug!(id = ?response.id, body = %body, "Response sent");
782            } else {
783                tracing::debug!(id = ?response.id, body_size = body.len(), "Response sent (body truncated)");
784            }
785        }
786
787        Ok(())
788    }
789
790    fn name(&self) -> &'static str {
791        "logging"
792    }
793
794    fn priority(&self) -> u32 {
795        1000 // Low priority - log everything
796    }
797}
798
799/// HTTP Security Headers middleware for defense-in-depth security
800#[derive(Debug, Clone)]
801pub struct SecurityHeadersMiddleware {
802    /// Security headers configuration
803    config: SecurityHeadersConfig,
804}
805
/// Header values for [`SecurityHeadersMiddleware`].
///
/// Each `Option` field holds the header's value; `None` leaves that header out.
#[derive(Debug, Clone)]
pub struct SecurityHeadersConfig {
    /// `Content-Security-Policy` value.
    pub content_security_policy: Option<String>,
    /// `X-Frame-Options` value.
    pub x_frame_options: Option<String>,
    /// When `true`, emit `X-Content-Type-Options: nosniff`.
    pub x_content_type_options: bool,
    /// `X-XSS-Protection` value.
    pub x_xss_protection: Option<String>,
    /// `Strict-Transport-Security` (HSTS) value.
    pub strict_transport_security: Option<String>,
    /// `Referrer-Policy` value.
    pub referrer_policy: Option<String>,
    /// `Permissions-Policy` value.
    pub permissions_policy: Option<String>,
    /// `Cross-Origin-Embedder-Policy` value.
    pub cross_origin_embedder_policy: Option<String>,
    /// `Cross-Origin-Opener-Policy` value.
    pub cross_origin_opener_policy: Option<String>,
    /// `Cross-Origin-Resource-Policy` value.
    pub cross_origin_resource_policy: Option<String>,
    /// Extra headers, keyed by header name.
    pub custom_headers: HashMap<String, String>,
}
832
833impl Default for SecurityHeadersConfig {
834    fn default() -> Self {
835        Self {
836            // Secure defaults for MCP servers
837            content_security_policy: Some(
838                "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'; \
839                connect-src 'self'; img-src 'self' data:; font-src 'self'; object-src 'none'; \
840                media-src 'self'; frame-src 'none'; base-uri 'self'; form-action 'self'".to_string()
841            ),
842            x_frame_options: Some("DENY".to_string()),
843            x_content_type_options: true,
844            x_xss_protection: Some("1; mode=block".to_string()),
845            strict_transport_security: Some("max-age=31536000; includeSubDomains; preload".to_string()),
846            referrer_policy: Some("strict-origin-when-cross-origin".to_string()),
847            permissions_policy: Some(
848                "geolocation=(), microphone=(), camera=(), payment=(), usb=(), \
849                gyroscope=(), accelerometer=(), magnetometer=()".to_string()
850            ),
851            cross_origin_embedder_policy: Some("require-corp".to_string()),
852            cross_origin_opener_policy: Some("same-origin".to_string()),
853            cross_origin_resource_policy: Some("same-origin".to_string()),
854            custom_headers: HashMap::new(),
855        }
856    }
857}
858
859impl SecurityHeadersConfig {
860    /// Create a new security headers config
861    #[must_use]
862    pub fn new() -> Self {
863        Self::default()
864    }
865
866    /// Create a relaxed configuration for development
867    #[must_use]
868    pub fn relaxed() -> Self {
869        Self {
870            content_security_policy: Some(
871                "default-src 'self' 'unsafe-inline' 'unsafe-eval'".to_string(),
872            ),
873            x_frame_options: Some("SAMEORIGIN".to_string()),
874            x_content_type_options: true,
875            x_xss_protection: Some("1; mode=block".to_string()),
876            strict_transport_security: None, // Don't enforce HTTPS in dev
877            referrer_policy: Some("no-referrer-when-downgrade".to_string()),
878            permissions_policy: None,
879            cross_origin_embedder_policy: None,
880            cross_origin_opener_policy: None,
881            cross_origin_resource_policy: Some("cross-origin".to_string()),
882            custom_headers: HashMap::new(),
883        }
884    }
885
886    /// Create a strict configuration for production
887    #[must_use]
888    pub fn strict() -> Self {
889        Self {
890            content_security_policy: Some(
891                "default-src 'none'; script-src 'self'; style-src 'self'; \
892                connect-src 'self'; img-src 'self'; font-src 'self'; \
893                object-src 'none'; media-src 'none'; frame-src 'none'; \
894                base-uri 'none'; form-action 'none'"
895                    .to_string(),
896            ),
897            x_frame_options: Some("DENY".to_string()),
898            x_content_type_options: true,
899            x_xss_protection: Some("1; mode=block".to_string()),
900            strict_transport_security: Some(
901                "max-age=63072000; includeSubDomains; preload".to_string(),
902            ),
903            referrer_policy: Some("no-referrer".to_string()),
904            permissions_policy: Some(
905                "geolocation=(), microphone=(), camera=(), payment=(), usb=(), \
906                gyroscope=(), accelerometer=(), magnetometer=(), display-capture=(), \
907                screen-wake-lock=(), web-share=()"
908                    .to_string(),
909            ),
910            cross_origin_embedder_policy: Some("require-corp".to_string()),
911            cross_origin_opener_policy: Some("same-origin".to_string()),
912            cross_origin_resource_policy: Some("same-origin".to_string()),
913            custom_headers: HashMap::new(),
914        }
915    }
916
917    /// Add a custom header
918    #[must_use]
919    pub fn with_custom_header(mut self, name: String, value: String) -> Self {
920        self.custom_headers.insert(name, value);
921        self
922    }
923
924    /// Set Content Security Policy
925    #[must_use]
926    pub fn with_csp(mut self, csp: Option<String>) -> Self {
927        self.content_security_policy = csp;
928        self
929    }
930
931    /// Set Strict Transport Security
932    #[must_use]
933    pub fn with_hsts(mut self, hsts: Option<String>) -> Self {
934        self.strict_transport_security = hsts;
935        self
936    }
937}
938
939impl SecurityHeadersMiddleware {
940    /// Create new security headers middleware with default configuration
941    #[must_use]
942    pub fn new() -> Self {
943        Self {
944            config: SecurityHeadersConfig::default(),
945        }
946    }
947
948    /// Create with custom configuration
949    #[must_use]
950    pub const fn with_config(config: SecurityHeadersConfig) -> Self {
951        Self { config }
952    }
953
954    /// Create with relaxed configuration for development
955    #[must_use]
956    pub fn relaxed() -> Self {
957        Self {
958            config: SecurityHeadersConfig::relaxed(),
959        }
960    }
961
962    /// Create with strict configuration for production
963    #[must_use]
964    pub fn strict() -> Self {
965        Self {
966            config: SecurityHeadersConfig::strict(),
967        }
968    }
969}
970
971impl Default for SecurityHeadersMiddleware {
972    fn default() -> Self {
973        Self::new()
974    }
975}
976
977#[async_trait]
978impl Middleware for SecurityHeadersMiddleware {
979    async fn process_request(
980        &self,
981        _request: &mut JsonRpcRequest,
982        _ctx: &mut RequestContext,
983    ) -> ServerResult<()> {
984        // Security headers are applied on the response
985        Ok(())
986    }
987
988    async fn process_response(
989        &self,
990        response: &mut JsonRpcResponse,
991        ctx: &RequestContext,
992    ) -> ServerResult<()> {
993        // Add security headers to the response object itself
994        // The transport layer can read these headers and apply them to the HTTP response
995        let mut security_headers = HashMap::new();
996
997        // Content Security Policy
998        if let Some(csp) = &self.config.content_security_policy {
999            security_headers.insert("Content-Security-Policy".to_string(), csp.clone());
1000        }
1001
1002        // X-Frame-Options
1003        if let Some(xfo) = &self.config.x_frame_options {
1004            security_headers.insert("X-Frame-Options".to_string(), xfo.clone());
1005        }
1006
1007        // X-Content-Type-Options
1008        if self.config.x_content_type_options {
1009            security_headers.insert("X-Content-Type-Options".to_string(), "nosniff".to_string());
1010        }
1011
1012        // X-XSS-Protection
1013        if let Some(xss) = &self.config.x_xss_protection {
1014            security_headers.insert("X-XSS-Protection".to_string(), xss.clone());
1015        }
1016
1017        // Strict-Transport-Security
1018        if let Some(hsts) = &self.config.strict_transport_security {
1019            security_headers.insert("Strict-Transport-Security".to_string(), hsts.clone());
1020        }
1021
1022        // Referrer-Policy
1023        if let Some(rp) = &self.config.referrer_policy {
1024            security_headers.insert("Referrer-Policy".to_string(), rp.clone());
1025        }
1026
1027        // Permissions-Policy
1028        if let Some(pp) = &self.config.permissions_policy {
1029            security_headers.insert("Permissions-Policy".to_string(), pp.clone());
1030        }
1031
1032        // Cross-Origin-Embedder-Policy
1033        if let Some(coep) = &self.config.cross_origin_embedder_policy {
1034            security_headers.insert("Cross-Origin-Embedder-Policy".to_string(), coep.clone());
1035        }
1036
1037        // Cross-Origin-Opener-Policy
1038        if let Some(coop) = &self.config.cross_origin_opener_policy {
1039            security_headers.insert("Cross-Origin-Opener-Policy".to_string(), coop.clone());
1040        }
1041
1042        // Cross-Origin-Resource-Policy
1043        if let Some(corp) = &self.config.cross_origin_resource_policy {
1044            security_headers.insert("Cross-Origin-Resource-Policy".to_string(), corp.clone());
1045        }
1046
1047        // Custom headers
1048        for (name, value) in &self.config.custom_headers {
1049            security_headers.insert(name.clone(), value.clone());
1050        }
1051
1052        // Store security headers in the response for the transport layer to read
1053        // We add this as a special field that the transport can detect
1054        if let Some(result) = &mut response.result {
1055            if let Some(obj) = result.as_object_mut() {
1056                obj.insert(
1057                    "_security_headers".to_string(),
1058                    serde_json::to_value(&security_headers)?,
1059                );
1060            }
1061        } else {
1062            // If there's no result, add it as metadata
1063            response.result = Some(serde_json::json!({
1064                "_security_headers": security_headers
1065            }));
1066        }
1067
1068        tracing::debug!(
1069            request_id = %ctx.request_id,
1070            headers_count = security_headers.len(),
1071            "Applied security headers to response"
1072        );
1073
1074        Ok(())
1075    }
1076
1077    fn name(&self) -> &'static str {
1078        "security_headers"
1079    }
1080
1081    fn priority(&self) -> u32 {
1082        900 // Apply late in the response pipeline, but before logging
1083    }
1084}
1085
1086/// Middleware layer for easier composition
1087pub type MiddlewareLayer = Arc<dyn Middleware>;
1088
/// Current wall-clock time as nanoseconds since the Unix epoch.
///
/// Returns `0` if the system clock reports a time before the epoch. The
/// `u128` nanosecond count saturates at `u64::MAX` instead of being silently
/// truncated by an `as` cast.
fn start_ts() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|d| u64::try_from(d.as_nanos()).unwrap_or(u64::MAX))
        .unwrap_or(0)
}