1use async_trait::async_trait;
4use std::collections::HashMap;
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7use tokio::sync::RwLock;
8use turbomcp_core::RequestContext;
9use turbomcp_protocol::jsonrpc::{JsonRpcRequest, JsonRpcResponse};
10
11use crate::{ServerError, ServerResult};
12
13#[async_trait]
15pub trait Middleware: Send + Sync {
16 async fn process_request(
18 &self,
19 request: &mut JsonRpcRequest,
20 ctx: &mut RequestContext,
21 ) -> ServerResult<()>;
22
23 async fn process_response(
25 &self,
26 response: &mut JsonRpcResponse,
27 ctx: &RequestContext,
28 ) -> ServerResult<()>;
29
30 fn name(&self) -> &str;
32
33 fn priority(&self) -> u32 {
35 100
36 }
37
38 fn enabled(&self) -> bool {
40 true
41 }
42}
43
44pub struct MiddlewareStack {
46 middleware: Vec<Arc<dyn Middleware>>,
48 config: StackConfig,
50}
51
52impl std::fmt::Debug for MiddlewareStack {
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 f.debug_struct("MiddlewareStack")
55 .field("middleware_count", &self.middleware.len())
56 .field("config", &self.config)
57 .finish()
58 }
59}
60
/// Settings for a [`MiddlewareStack`].
#[derive(Debug, Clone)]
pub struct StackConfig {
    /// Metrics switch. Not read anywhere in this file — presumably consumed
    /// elsewhere; TODO confirm.
    pub enable_metrics: bool,
    /// Emit a `tracing::debug!` event with the duration of each middleware
    /// invocation.
    pub enable_tracing: bool,
    /// Per-middleware time limit in milliseconds; `0` disables the limit.
    pub timeout_ms: u64,
    /// When `true`, a middleware error or timeout is logged and processing
    /// moves on to the next middleware instead of failing the request.
    ///
    /// NOTE(review): this applies to every middleware, so with recovery
    /// enabled a failing authentication or rate-limit middleware does not
    /// stop the request. Confirm this is intended for the default.
    pub enable_recovery: bool,
}

impl Default for StackConfig {
    /// Everything enabled, with a five second per-middleware timeout.
    fn default() -> Self {
        let on = true;
        StackConfig {
            enable_metrics: on,
            enable_tracing: on,
            timeout_ms: 5_000,
            enable_recovery: on,
        }
    }
}
84
85impl MiddlewareStack {
86 #[must_use]
88 pub fn new() -> Self {
89 Self {
90 middleware: Vec::new(),
91 config: StackConfig::default(),
92 }
93 }
94
95 #[must_use]
97 pub fn with_config(config: StackConfig) -> Self {
98 Self {
99 middleware: Vec::new(),
100 config,
101 }
102 }
103
104 pub fn add<M>(&mut self, middleware: M)
106 where
107 M: Middleware + 'static,
108 {
109 self.middleware.push(Arc::new(middleware));
110 self.sort_by_priority();
111 }
112
113 pub fn remove(&mut self, name: &str) {
115 self.middleware.retain(|m| m.name() != name);
116 }
117
118 pub async fn process_request(
120 &self,
121 mut request: JsonRpcRequest,
122 mut ctx: RequestContext,
123 ) -> ServerResult<(JsonRpcRequest, RequestContext)> {
124 let global_start = Instant::now();
126 for middleware in &self.middleware {
127 if !middleware.enabled() {
128 continue;
129 }
130
131 let start = Instant::now();
132
133 let result = if self.config.timeout_ms > 0 {
135 tokio::time::timeout(
136 Duration::from_millis(self.config.timeout_ms),
137 middleware.process_request(&mut request, &mut ctx),
138 )
139 .await
140 } else {
141 Ok(middleware.process_request(&mut request, &mut ctx).await)
142 };
143
144 let duration = start.elapsed();
145
146 if self.config.enable_tracing {
147 tracing::debug!(
148 middleware = middleware.name(),
149 duration_ms = duration.as_millis(),
150 "Processed request through middleware"
151 );
152 }
153
154 match result {
155 Ok(Ok(())) => continue,
156 Ok(Err(e)) => {
157 if self.config.enable_recovery {
158 tracing::warn!(
159 middleware = middleware.name(),
160 error = %e,
161 "Middleware error, continuing with recovery"
162 );
163 continue;
164 }
165 return Err(ServerError::middleware(middleware.name(), e.to_string()));
166 }
167 Err(_) => {
168 let _error = format!(
169 "Middleware '{}' timed out after {}ms",
170 middleware.name(),
171 self.config.timeout_ms
172 );
173 if self.config.enable_recovery {
174 tracing::warn!(
175 middleware = middleware.name(),
176 "Middleware timeout, continuing"
177 );
178 continue;
179 }
180 return Err(ServerError::timeout("middleware", self.config.timeout_ms));
181 }
182 }
183 }
184
185 let correlation_id = ctx
187 .metadata
188 .get("correlation_id")
189 .and_then(|v| v.as_str())
190 .map_or_else(
191 || uuid::Uuid::new_v4().to_string(),
192 std::string::ToString::to_string,
193 );
194 ctx = ctx.with_metadata("correlation_id", correlation_id);
195
196 let start_ns = start_ts();
198 let request_id = ctx.request_id.clone();
199 ctx = ctx.with_metadata("request_start_ns", start_ns);
200 ctx = ctx.with_metadata("request_id", request_id);
201 ctx = ctx.with_metadata(
203 "middleware_time_ms",
204 global_start.elapsed().as_millis() as u64,
205 );
206 Ok((request, ctx))
207 }
208
209 pub async fn process_response(
211 &self,
212 mut response: JsonRpcResponse,
213 ctx: &RequestContext,
214 ) -> ServerResult<JsonRpcResponse> {
215 for middleware in self.middleware.iter().rev() {
216 if !middleware.enabled() {
217 continue;
218 }
219
220 let start = Instant::now();
221
222 let result = if self.config.timeout_ms > 0 {
224 tokio::time::timeout(
225 Duration::from_millis(self.config.timeout_ms),
226 middleware.process_response(&mut response, ctx),
227 )
228 .await
229 } else {
230 Ok(middleware.process_response(&mut response, ctx).await)
231 };
232
233 let duration = start.elapsed();
234
235 if self.config.enable_tracing {
236 tracing::debug!(
237 middleware = middleware.name(),
238 duration_ms = duration.as_millis(),
239 "Processed response through middleware"
240 );
241 }
242
243 match result {
244 Ok(Ok(())) => continue,
245 Ok(Err(e)) => {
246 if self.config.enable_recovery {
247 tracing::warn!(
248 middleware = middleware.name(),
249 error = %e,
250 "Middleware error in response processing, continuing"
251 );
252 continue;
253 }
254 return Err(ServerError::middleware(middleware.name(), e.to_string()));
255 }
256 Err(_) => {
257 if self.config.enable_recovery {
258 tracing::warn!(
259 middleware = middleware.name(),
260 "Middleware timeout in response processing, continuing"
261 );
262 continue;
263 }
264 return Err(ServerError::timeout("middleware", self.config.timeout_ms));
265 }
266 }
267 }
268
269 if let Some(ns) = ctx
271 .metadata
272 .get("request_start_ns")
273 .and_then(serde_json::Value::as_u64)
274 {
275 let end_ns = start_ts();
276 let elapsed_ns = end_ns.saturating_sub(ns);
277 let latency_ms = (elapsed_ns as f64) / 1_000_000.0;
278 tracing::debug!(
279 correlation_id = ctx.metadata.get("correlation_id").and_then(|v| v.as_str()),
280 request_id = %ctx.request_id,
281 latency_ms,
282 "Request completed with latency"
283 );
284 }
285 Ok(response)
286 }
287
288 #[must_use]
290 pub fn len(&self) -> usize {
291 self.middleware.len()
292 }
293
294 #[must_use]
296 pub fn is_empty(&self) -> bool {
297 self.middleware.is_empty()
298 }
299
300 #[must_use]
302 pub fn list_middleware(&self) -> Vec<&str> {
303 self.middleware.iter().map(|m| m.name()).collect()
304 }
305
306 fn sort_by_priority(&mut self) {
307 self.middleware.sort_by_key(|m| m.priority());
308 }
309}
310
311impl Default for MiddlewareStack {
312 fn default() -> Self {
313 Self::new()
314 }
315}
316
317pub struct AuthenticationMiddleware {
319 provider: Arc<dyn AuthProvider>,
321 config: AuthConfig,
323}
324
325impl std::fmt::Debug for AuthenticationMiddleware {
326 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
327 f.debug_struct("AuthenticationMiddleware")
328 .field("config", &self.config)
329 .finish()
330 }
331}
332
/// Settings for [`AuthenticationMiddleware`].
#[derive(Debug, Clone)]
pub struct AuthConfig {
    /// JSON-RPC method names that bypass authentication entirely.
    pub skip_methods: Vec<String>,
    /// Credential scheme. Not read in this file — presumably interpreted by
    /// the provider; TODO confirm.
    pub scheme: AuthScheme,
    /// Token lifetime. Not read in this file — presumably interpreted by the
    /// provider; TODO confirm.
    pub token_expiry: Duration,
}

/// Supported ways of presenting credentials.
#[derive(Debug, Clone)]
pub enum AuthScheme {
    /// Bearer token.
    Bearer,
    /// API key.
    ApiKey,
    /// Basic credentials.
    Basic,
    /// Any other scheme, identified by name.
    Custom(String),
}
356
357#[async_trait]
359pub trait AuthProvider: Send + Sync {
360 async fn authenticate(&self, request: &JsonRpcRequest) -> ServerResult<AuthContext>;
362
363 async fn validate_token(&self, token: &str) -> ServerResult<AuthContext>;
365}
366
367#[derive(Debug, Clone)]
369pub struct AuthContext {
370 pub user_id: String,
372 pub roles: Vec<String>,
374 pub expires_at: Option<chrono::DateTime<chrono::Utc>>,
376 pub claims: HashMap<String, serde_json::Value>,
378}
379
380impl AuthenticationMiddleware {
381 pub fn new<P>(provider: P) -> Self
383 where
384 P: AuthProvider + 'static,
385 {
386 Self {
387 provider: Arc::new(provider),
388 config: AuthConfig {
389 skip_methods: vec!["initialize".to_string()],
390 scheme: AuthScheme::Bearer,
391 token_expiry: Duration::from_secs(3600),
392 },
393 }
394 }
395
396 pub fn with_config<P>(provider: P, config: AuthConfig) -> Self
398 where
399 P: AuthProvider + 'static,
400 {
401 Self {
402 provider: Arc::new(provider),
403 config,
404 }
405 }
406}
407
408#[async_trait]
409impl Middleware for AuthenticationMiddleware {
410 async fn process_request(
411 &self,
412 request: &mut JsonRpcRequest,
413 _ctx: &mut RequestContext,
414 ) -> ServerResult<()> {
415 if self.config.skip_methods.contains(&request.method) {
417 return Ok(());
418 }
419
420 match self.provider.authenticate(request).await {
421 Ok(auth_ctx) => {
422 _ctx.user_id = Some(auth_ctx.user_id.clone());
424 let meta = std::sync::Arc::make_mut(&mut _ctx.metadata);
425 meta.insert("authenticated".to_string(), serde_json::json!(true));
426 meta.insert(
427 "auth".to_string(),
428 serde_json::json!({
429 "user_id": auth_ctx.user_id,
430 "roles": auth_ctx.roles,
431 "expires_at": auth_ctx.expires_at.map(|t| t.to_rfc3339()),
432 "claims": auth_ctx.claims,
433 }),
434 );
435 Ok(())
436 }
437 Err(e) => Err(ServerError::authentication(format!(
438 "Authentication failed: {e}"
439 ))),
440 }
441 }
442
443 async fn process_response(
444 &self,
445 _response: &mut JsonRpcResponse,
446 _ctx: &RequestContext,
447 ) -> ServerResult<()> {
448 Ok(())
449 }
450
451 fn name(&self) -> &'static str {
452 "authentication"
453 }
454
455 fn priority(&self) -> u32 {
456 10 }
458}
459
460#[derive(Debug)]
462pub struct RateLimitMiddleware {
463 limiter: Arc<RateLimiter>,
465 config: RateLimitConfig,
467}
468
/// Settings for [`RateLimitMiddleware`].
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Sustained refill rate, in tokens per second.
    pub requests_per_second: u32,
    /// Maximum number of tokens a bucket can hold.
    pub burst_capacity: u32,
    /// How the bucket key is derived from the request context.
    pub key_extractor: KeyExtractor,
}

/// Source of the rate-limit bucket key.
#[derive(Debug, Clone)]
pub enum KeyExtractor {
    /// The `client_ip` metadata entry.
    ClientIp,
    /// The `user_id` field of the `auth` metadata entry.
    UserId,
    /// The `api_key` metadata entry.
    ApiKey,
    /// The metadata entry with the given name.
    Custom(String),
    /// A single bucket shared by every request.
    Global,
}
494
495#[derive(Debug)]
497pub struct RateLimiter {
498 entries: Arc<RwLock<HashMap<String, RateLimitEntry>>>,
500 _cleanup_handle: Option<tokio::task::JoinHandle<()>>,
502}
503
504#[derive(Debug, Clone)]
506struct RateLimitEntry {
507 tokens: u32,
509 last_refill: Instant,
511 expires_at: Instant,
513}
514
515impl RateLimiter {
516 #[must_use]
518 pub fn new(_requests_per_second: u32, _burst_capacity: u32) -> Self {
519 let entries = Arc::new(RwLock::new(HashMap::<String, RateLimitEntry>::new()));
520
521 let cleanup_entries = Arc::clone(&entries);
523 let cleanup_handle = tokio::spawn(async move {
524 let mut interval = tokio::time::interval(Duration::from_secs(60));
525 loop {
526 interval.tick().await;
527 let now = Instant::now();
528 let mut entries = cleanup_entries.write().await;
529 entries.retain(|_, entry| entry.expires_at > now);
530 }
531 });
532
533 Self {
534 entries,
535 _cleanup_handle: Some(cleanup_handle),
536 }
537 }
538
539 #[must_use]
541 #[cfg(test)]
542 pub fn new_for_testing(_requests_per_second: u32, _burst_capacity: u32) -> Self {
543 let entries = Arc::new(RwLock::new(HashMap::<String, RateLimitEntry>::new()));
544
545 Self {
546 entries,
547 _cleanup_handle: None, }
549 }
550
551 pub async fn check_rate_limit(
553 &self,
554 key: &str,
555 requests_per_second: u32,
556 burst_capacity: u32,
557 ) -> bool {
558 let mut entries = self.entries.write().await;
559 let now = Instant::now();
560
561 let entry = entries.entry(key.to_string()).or_insert(RateLimitEntry {
562 tokens: burst_capacity,
563 last_refill: now,
564 expires_at: now + Duration::from_secs(300), });
566
567 let time_elapsed = now.duration_since(entry.last_refill);
569 let tokens_to_add = (time_elapsed.as_secs_f64() * f64::from(requests_per_second)) as u32;
570
571 if tokens_to_add > 0 {
572 entry.tokens = (entry.tokens + tokens_to_add).min(burst_capacity);
573 entry.last_refill = now;
574 }
575
576 if entry.tokens > 0 {
577 entry.tokens -= 1;
578 entry.expires_at = now + Duration::from_secs(300);
579 true
580 } else {
581 false
582 }
583 }
584}
585
586impl RateLimitMiddleware {
587 #[must_use]
589 pub fn new(config: RateLimitConfig) -> Self {
590 let limiter = Arc::new(RateLimiter::new(
591 config.requests_per_second,
592 config.burst_capacity,
593 ));
594
595 Self { limiter, config }
596 }
597
598 #[must_use]
600 #[cfg(test)]
601 pub fn new_for_testing(config: RateLimitConfig) -> Self {
602 let limiter = Arc::new(RateLimiter::new_for_testing(
603 config.requests_per_second,
604 config.burst_capacity,
605 ));
606
607 Self { limiter, config }
608 }
609}
610
611#[async_trait]
612impl Middleware for RateLimitMiddleware {
613 async fn process_request(
614 &self,
615 _request: &mut JsonRpcRequest,
616 ctx: &mut RequestContext,
617 ) -> ServerResult<()> {
618 let key = match &self.config.key_extractor {
619 KeyExtractor::ClientIp => ctx
620 .metadata
621 .get("client_ip")
622 .and_then(|v| v.as_str())
623 .unwrap_or("unknown")
624 .to_string(),
625 KeyExtractor::UserId => ctx
626 .metadata
627 .get("auth")
628 .and_then(|v| v.get("user_id"))
629 .and_then(|v| v.as_str())
630 .unwrap_or("anonymous")
631 .to_string(),
632 KeyExtractor::ApiKey => ctx
633 .metadata
634 .get("api_key")
635 .and_then(|v| v.as_str())
636 .unwrap_or("unknown")
637 .to_string(),
638 KeyExtractor::Custom(field) => ctx
639 .metadata
640 .get(field)
641 .and_then(|v| v.as_str())
642 .unwrap_or("unknown")
643 .to_string(),
644 KeyExtractor::Global => "global".to_string(),
645 };
646
647 let allowed = self
648 .limiter
649 .check_rate_limit(
650 &key,
651 self.config.requests_per_second,
652 self.config.burst_capacity,
653 )
654 .await;
655
656 if allowed {
657 Ok(())
658 } else {
659 Err(ServerError::rate_limit_with_retry(
660 format!("Rate limit exceeded for key: {key}"),
661 60, ))
663 }
664 }
665
666 async fn process_response(
667 &self,
668 _response: &mut JsonRpcResponse,
669 _ctx: &RequestContext,
670 ) -> ServerResult<()> {
671 Ok(())
672 }
673
674 fn name(&self) -> &'static str {
675 "rate_limit"
676 }
677
678 fn priority(&self) -> u32 {
679 20 }
681}
682
683#[derive(Debug)]
685pub struct LoggingMiddleware {
686 config: LoggingConfig,
688}
689
/// Settings for [`LoggingMiddleware`].
#[derive(Debug, Clone)]
pub struct LoggingConfig {
    /// Include the serialized request in the log event.
    pub log_request_body: bool,
    /// Include the serialized response in the log event.
    pub log_response_body: bool,
    /// Log the elapsed time when a response is processed.
    pub log_timing: bool,
    /// Largest serialized body, in bytes, that is logged verbatim; larger
    /// bodies are replaced by their size.
    pub max_body_size: usize,
}

impl Default for LoggingConfig {
    /// Bodies off, timing on, 1024-byte body limit.
    fn default() -> Self {
        LoggingConfig {
            max_body_size: 1024,
            log_timing: true,
            log_response_body: false,
            log_request_body: false,
        }
    }
}
713
714impl LoggingMiddleware {
715 #[must_use]
717 pub fn new() -> Self {
718 Self {
719 config: LoggingConfig::default(),
720 }
721 }
722
723 #[must_use]
725 pub const fn with_config(config: LoggingConfig) -> Self {
726 Self { config }
727 }
728}
729
730impl Default for LoggingMiddleware {
731 fn default() -> Self {
732 Self::new()
733 }
734}
735
736#[async_trait]
737impl Middleware for LoggingMiddleware {
738 async fn process_request(
739 &self,
740 request: &mut JsonRpcRequest,
741 ctx: &mut RequestContext,
742 ) -> ServerResult<()> {
743 let _start_time = ctx.start_time;
745
746 if self.config.log_request_body {
747 if let Ok(body) = serde_json::to_string(request) {
748 if body.len() <= self.config.max_body_size {
749 tracing::info!(method = %request.method, body = %body, "Request received");
750 } else {
751 tracing::info!(method = %request.method, body_size = body.len(), "Request received (body truncated)");
752 }
753 }
754 } else {
755 tracing::info!(method = %request.method, id = ?request.id, "Request received");
756 }
757
758 Ok(())
759 }
760
761 async fn process_response(
762 &self,
763 response: &mut JsonRpcResponse,
764 ctx: &RequestContext,
765 ) -> ServerResult<()> {
766 if self.config.log_timing {
767 let duration = ctx.start_time.elapsed();
769 tracing::info!(
770 id = ?response.id,
771 has_error = response.error.is_some(),
772 duration_ms = duration.as_millis(),
773 "Request completed"
774 );
775 }
776
777 if self.config.log_response_body
778 && let Ok(body) = serde_json::to_string(response)
779 {
780 if body.len() <= self.config.max_body_size {
781 tracing::debug!(id = ?response.id, body = %body, "Response sent");
782 } else {
783 tracing::debug!(id = ?response.id, body_size = body.len(), "Response sent (body truncated)");
784 }
785 }
786
787 Ok(())
788 }
789
790 fn name(&self) -> &'static str {
791 "logging"
792 }
793
794 fn priority(&self) -> u32 {
795 1000 }
797}
798
799#[derive(Debug, Clone)]
801pub struct SecurityHeadersMiddleware {
802 config: SecurityHeadersConfig,
804}
805
/// Header values used by [`SecurityHeadersMiddleware`].
///
/// Each `Option` field set to `None` omits the corresponding header.
#[derive(Debug, Clone)]
pub struct SecurityHeadersConfig {
    /// `Content-Security-Policy` value.
    pub content_security_policy: Option<String>,
    /// `X-Frame-Options` value.
    pub x_frame_options: Option<String>,
    /// When `true`, emits `X-Content-Type-Options: nosniff`.
    pub x_content_type_options: bool,
    /// `X-XSS-Protection` value.
    pub x_xss_protection: Option<String>,
    /// `Strict-Transport-Security` value.
    pub strict_transport_security: Option<String>,
    /// `Referrer-Policy` value.
    pub referrer_policy: Option<String>,
    /// `Permissions-Policy` value.
    pub permissions_policy: Option<String>,
    /// `Cross-Origin-Embedder-Policy` value.
    pub cross_origin_embedder_policy: Option<String>,
    /// `Cross-Origin-Opener-Policy` value.
    pub cross_origin_opener_policy: Option<String>,
    /// `Cross-Origin-Resource-Policy` value.
    pub cross_origin_resource_policy: Option<String>,
    /// Extra headers; an entry with the same name as a built-in header
    /// replaces the built-in value.
    pub custom_headers: HashMap<String, String>,
}

impl Default for SecurityHeadersConfig {
    /// Balanced preset: a self-only CSP that still permits inline scripts and
    /// styles, framing denied, one-year HSTS.
    fn default() -> Self {
        let header = |value: &str| Some(value.to_owned());
        Self {
            content_security_policy: header(
                "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'; \
                 connect-src 'self'; img-src 'self' data:; font-src 'self'; object-src 'none'; \
                 media-src 'self'; frame-src 'none'; base-uri 'self'; form-action 'self'",
            ),
            x_frame_options: header("DENY"),
            x_content_type_options: true,
            x_xss_protection: header("1; mode=block"),
            strict_transport_security: header("max-age=31536000; includeSubDomains; preload"),
            referrer_policy: header("strict-origin-when-cross-origin"),
            permissions_policy: header(
                "geolocation=(), microphone=(), camera=(), payment=(), usb=(), \
                 gyroscope=(), accelerometer=(), magnetometer=()",
            ),
            cross_origin_embedder_policy: header("require-corp"),
            cross_origin_opener_policy: header("same-origin"),
            cross_origin_resource_policy: header("same-origin"),
            custom_headers: HashMap::new(),
        }
    }
}

impl SecurityHeadersConfig {
    /// Same as [`Default::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Permissive preset: inline/eval allowed in the CSP, same-origin
    /// framing, no HSTS, no permissions policy and no COEP/COOP.
    #[must_use]
    pub fn relaxed() -> Self {
        let header = |value: &str| Some(value.to_owned());
        Self {
            content_security_policy: header("default-src 'self' 'unsafe-inline' 'unsafe-eval'"),
            x_frame_options: header("SAMEORIGIN"),
            x_content_type_options: true,
            x_xss_protection: header("1; mode=block"),
            strict_transport_security: None,
            referrer_policy: header("no-referrer-when-downgrade"),
            permissions_policy: None,
            cross_origin_embedder_policy: None,
            cross_origin_opener_policy: None,
            cross_origin_resource_policy: header("cross-origin"),
            custom_headers: HashMap::new(),
        }
    }

    /// Restrictive preset: deny-by-default CSP, two-year HSTS, no referrer,
    /// and an extended permissions policy.
    #[must_use]
    pub fn strict() -> Self {
        let header = |value: &str| Some(value.to_owned());
        Self {
            content_security_policy: header(
                "default-src 'none'; script-src 'self'; style-src 'self'; \
                 connect-src 'self'; img-src 'self'; font-src 'self'; \
                 object-src 'none'; media-src 'none'; frame-src 'none'; \
                 base-uri 'none'; form-action 'none'",
            ),
            x_frame_options: header("DENY"),
            x_content_type_options: true,
            x_xss_protection: header("1; mode=block"),
            strict_transport_security: header("max-age=63072000; includeSubDomains; preload"),
            referrer_policy: header("no-referrer"),
            permissions_policy: header(
                "geolocation=(), microphone=(), camera=(), payment=(), usb=(), \
                 gyroscope=(), accelerometer=(), magnetometer=(), display-capture=(), \
                 screen-wake-lock=(), web-share=()",
            ),
            cross_origin_embedder_policy: header("require-corp"),
            cross_origin_opener_policy: header("same-origin"),
            cross_origin_resource_policy: header("same-origin"),
            custom_headers: HashMap::new(),
        }
    }

    /// Adds (or replaces) a custom header and returns the updated config.
    #[must_use]
    pub fn with_custom_header(self, name: String, value: String) -> Self {
        let mut updated = self;
        updated.custom_headers.insert(name, value);
        updated
    }

    /// Replaces the CSP value; `None` removes the header.
    #[must_use]
    pub fn with_csp(self, csp: Option<String>) -> Self {
        Self {
            content_security_policy: csp,
            ..self
        }
    }

    /// Replaces the HSTS value; `None` removes the header.
    #[must_use]
    pub fn with_hsts(self, hsts: Option<String>) -> Self {
        Self {
            strict_transport_security: hsts,
            ..self
        }
    }
}
938
939impl SecurityHeadersMiddleware {
940 #[must_use]
942 pub fn new() -> Self {
943 Self {
944 config: SecurityHeadersConfig::default(),
945 }
946 }
947
948 #[must_use]
950 pub const fn with_config(config: SecurityHeadersConfig) -> Self {
951 Self { config }
952 }
953
954 #[must_use]
956 pub fn relaxed() -> Self {
957 Self {
958 config: SecurityHeadersConfig::relaxed(),
959 }
960 }
961
962 #[must_use]
964 pub fn strict() -> Self {
965 Self {
966 config: SecurityHeadersConfig::strict(),
967 }
968 }
969}
970
971impl Default for SecurityHeadersMiddleware {
972 fn default() -> Self {
973 Self::new()
974 }
975}
976
977#[async_trait]
978impl Middleware for SecurityHeadersMiddleware {
979 async fn process_request(
980 &self,
981 _request: &mut JsonRpcRequest,
982 _ctx: &mut RequestContext,
983 ) -> ServerResult<()> {
984 Ok(())
986 }
987
988 async fn process_response(
989 &self,
990 response: &mut JsonRpcResponse,
991 ctx: &RequestContext,
992 ) -> ServerResult<()> {
993 let mut security_headers = HashMap::new();
996
997 if let Some(csp) = &self.config.content_security_policy {
999 security_headers.insert("Content-Security-Policy".to_string(), csp.clone());
1000 }
1001
1002 if let Some(xfo) = &self.config.x_frame_options {
1004 security_headers.insert("X-Frame-Options".to_string(), xfo.clone());
1005 }
1006
1007 if self.config.x_content_type_options {
1009 security_headers.insert("X-Content-Type-Options".to_string(), "nosniff".to_string());
1010 }
1011
1012 if let Some(xss) = &self.config.x_xss_protection {
1014 security_headers.insert("X-XSS-Protection".to_string(), xss.clone());
1015 }
1016
1017 if let Some(hsts) = &self.config.strict_transport_security {
1019 security_headers.insert("Strict-Transport-Security".to_string(), hsts.clone());
1020 }
1021
1022 if let Some(rp) = &self.config.referrer_policy {
1024 security_headers.insert("Referrer-Policy".to_string(), rp.clone());
1025 }
1026
1027 if let Some(pp) = &self.config.permissions_policy {
1029 security_headers.insert("Permissions-Policy".to_string(), pp.clone());
1030 }
1031
1032 if let Some(coep) = &self.config.cross_origin_embedder_policy {
1034 security_headers.insert("Cross-Origin-Embedder-Policy".to_string(), coep.clone());
1035 }
1036
1037 if let Some(coop) = &self.config.cross_origin_opener_policy {
1039 security_headers.insert("Cross-Origin-Opener-Policy".to_string(), coop.clone());
1040 }
1041
1042 if let Some(corp) = &self.config.cross_origin_resource_policy {
1044 security_headers.insert("Cross-Origin-Resource-Policy".to_string(), corp.clone());
1045 }
1046
1047 for (name, value) in &self.config.custom_headers {
1049 security_headers.insert(name.clone(), value.clone());
1050 }
1051
1052 if let Some(result) = &mut response.result {
1055 if let Some(obj) = result.as_object_mut() {
1056 obj.insert(
1057 "_security_headers".to_string(),
1058 serde_json::to_value(&security_headers)?,
1059 );
1060 }
1061 } else {
1062 response.result = Some(serde_json::json!({
1064 "_security_headers": security_headers
1065 }));
1066 }
1067
1068 tracing::debug!(
1069 request_id = %ctx.request_id,
1070 headers_count = security_headers.len(),
1071 "Applied security headers to response"
1072 );
1073
1074 Ok(())
1075 }
1076
1077 fn name(&self) -> &'static str {
1078 "security_headers"
1079 }
1080
1081 fn priority(&self) -> u32 {
1082 900 }
1084}
1085
1086pub type MiddlewareLayer = Arc<dyn Middleware>;
1088
/// Returns the current wall-clock time as nanoseconds since the Unix epoch.
///
/// Yields `0` if the system clock is set before the epoch and `u64::MAX` if
/// the nanosecond count no longer fits in a `u64` (instead of silently
/// truncating the `u128` value). The wall clock is not monotonic, so
/// differences between two readings should be taken with `saturating_sub`.
fn start_ts() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |since_epoch| {
            u64::try_from(since_epoch.as_nanos()).unwrap_or(u64::MAX)
        })
}